Create app.py
app.py
ADDED
@@ -0,0 +1,211 @@
import streamlit as st
import open_clip
import torch
import requests
from PIL import Image
from io import BytesIO
import time
import numpy as np
from ultralytics import YOLO
import chromadb
from transformers import pipeline
from sklearn.metrics.pairwise import cosine_similarity

# Load segmentation model (cached so Streamlit reruns don't reload it)
@st.cache_resource
def load_segmenter():
    return pipeline(model="mattmdjaga/segformer_b2_clothes")

segmenter = load_segmenter()
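
# The image-segmentation pipeline returns a list of dicts, each carrying a class
# 'label' and a PIL 'mask' for that class; segment_clothing() below filters these
# down to the clothing labels of interest.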

# Load CLIP model and tokenizer
@st.cache_resource
def load_clip_model():
    model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms('hf-hub:Marqo/marqo-fashionSigLIP')
    tokenizer = open_clip.get_tokenizer('hf-hub:Marqo/marqo-fashionSigLIP')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    return model, preprocess_val, tokenizer, device

clip_model, preprocess_val, tokenizer, device = load_clip_model()

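# st.cache_resource keeps a single copy of each heavy model per process, so
# Streamlit's rerun-per-interaction execution model doesn't reload weights on
# every widget click. Only preprocess_val is returned because inference should
# use the evaluation-time transforms, not the training augmentations.
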
# Load YOLOv8 model
@st.cache_resource
def load_yolo_model():
    return YOLO("./best.pt")

yolo_model = load_yolo_model()

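# Note: yolo_model is loaded but never called in the current flow; the SegFormer
# masks drive category detection. Keep it only if a detection step is planned.
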
# Load ChromaDB collection
client = chromadb.PersistentClient(path="./clothesDB_Test")
# collection = client.get_collection(name="clothes_items_ver3")
collection = client.get_collection(name="category_upper")

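# The collection is expected to hold one embedding per item (produced the same
# way as get_image_embedding below) plus the metadata fields read in the results
# view: 'image_url', 'name', 'brand', 'price', 'discount', and optionally 'category'.
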
# Helper functions
def load_image_from_url(url, max_retries=3):
    for attempt in range(max_retries):
        try:
            response = requests.get(url, timeout=10)
            response.raise_for_status()
            img = Image.open(BytesIO(response.content)).convert('RGB')
            return img
        except (requests.RequestException, Image.UnidentifiedImageError):
            if attempt < max_retries - 1:
                time.sleep(1)
            else:
                return None

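# Example (hypothetical URL): load_image_from_url("https://example.com/shirt.jpg")
# yields an RGB PIL.Image, or None once max_retries attempts have failed.
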
def get_image_embedding(image):
    image_tensor = preprocess_val(image).unsqueeze(0).to(device)
    with torch.no_grad():
        image_features = clip_model.encode_image(image_tensor)
        image_features /= image_features.norm(dim=-1, keepdim=True)
    return image_features.cpu().numpy()

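# The returned array has shape (1, D) and is L2-normalized, so cosine similarity
# and dot product produce identical rankings downstream.
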
def segment_clothing(img, clothes=["Hat", "Upper-clothes", "Skirt", "Pants", "Dress", "Belt", "Left-shoe", "Right-shoe", "Scarf"]):
    # Segment image
    segments = segmenter(img)

    # Collect the masks for the clothing labels of interest
    mask_list = []
    detected_categories = []
    for s in segments:
        if s['label'] in clothes:
            mask_list.append(s['mask'])
            detected_categories.append(s['label'])  # Store detected categories

    # Combine all masks into a single union mask
    final_mask = np.zeros_like(np.array(img)[:, :, 0])  # Initialize empty mask
    for mask in mask_list:
        current_mask = np.array(mask)
        final_mask = np.maximum(final_mask, current_mask)  # Union via elementwise max

    # Binarize and convert to a PIL image. The pipeline's masks are typically
    # 0/255 already, so multiplying by 255 would overflow uint8; thresholding
    # is safe either way.
    final_mask = Image.fromarray(np.where(final_mask > 0, 255, 0).astype(np.uint8))

    # Keep only the masked pixels. A plain RGBA -> RGB conversion would silently
    # discard the alpha channel and return the original image, so composite onto
    # a black background instead.
    background = Image.new("RGB", img.size, (0, 0, 0))
    segmented = Image.composite(img.convert("RGB"), background, final_mask)

    return segmented, final_mask, detected_categories


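# The default `clothes` list is a subset of the segformer_b2_clothes label set,
# which also includes non-garment classes (Hair, Face, limbs, Bag) that are
# deliberately excluded here.
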
# Earlier version using a raw dot product. Rankings are identical to the cosine
# version below whenever the stored embeddings are L2-normalized, as
# get_image_embedding produces them:
#
# def find_similar_images(query_embedding, collection, top_k=5):
#     all_embeddings = collection.get(include=['embeddings'])['embeddings']
#     database_embeddings = np.array(all_embeddings)
#     similarities = np.dot(database_embeddings, query_embedding.T).squeeze()
#     top_indices = np.argsort(similarities)[::-1][:top_k]
#
#     all_data = collection.get(include=['metadatas'])['metadatas']
#     top_metadatas = [all_data[idx] for idx in top_indices]
#
#     results = []
#     for idx, metadata in enumerate(top_metadatas):
#         results.append({
#             'info': metadata,
#             'similarity': similarities[top_indices[idx]]
#         })
#     return results

def find_similar_images(query_embedding, collection, top_k=5):
    # Fetch every stored embedding
    all_embeddings = collection.get(include=['embeddings'])['embeddings']
    database_embeddings = np.array(all_embeddings)

    # Compute similarity between the query and every item
    similarities = cosine_similarity(database_embeddings, query_embedding.reshape(1, -1)).squeeze()
    top_indices = np.argsort(similarities)[::-1][:top_k]

    # Fetch the metadata for the top matches
    all_data = collection.get(include=['metadatas'])['metadatas']
    top_metadatas = [all_data[idx] for idx in top_indices]

    results = []
    for idx, metadata in enumerate(top_metadatas):
        # The image_url field may hold several comma-separated URLs; split into a list
        image_urls = metadata['image_url'].split(',')
        # Use the first image as the representative one
        representative_image_url = image_urls[0] if image_urls else None

        results.append({
            'info': metadata,
            'similarity': similarities[top_indices[idx]],
            'image_url': representative_image_url  # First image URL
        })
    return results

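# A sketch of a leaner alternative, assuming the collection was built with a
# compatible distance metric: let Chroma run the nearest-neighbor search itself
# rather than pulling every embedding into memory, e.g.
#
#   hits = collection.query(
#       query_embeddings=query_embedding.reshape(1, -1).tolist(),
#       n_results=5,
#       include=['metadatas', 'distances'],
#   )
#
# The brute-force loop above is kept because it makes the scoring explicit.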

# Initialize session state
if 'step' not in st.session_state:
    st.session_state.step = 'input'
if 'query_image_url' not in st.session_state:
    st.session_state.query_image_url = ''
if 'detections' not in st.session_state:
    st.session_state.detections = []
if 'segmented_image' not in st.session_state:
    st.session_state.segmented_image = None
if 'selected_category' not in st.session_state:
    st.session_state.selected_category = None

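# The app is a three-step state machine driven by st.session_state.step:
# 'input' -> 'select_category' -> 'show_results', with "Start New Search"
# looping back to 'input'.
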
# Streamlit app
st.title("Advanced Fashion Search App")

# Step 1: enter an image URL and detect clothing
if st.session_state.step == 'input':
    st.session_state.query_image_url = st.text_input("Enter image URL:", st.session_state.query_image_url)
    if st.button("Detect Clothing"):
        if st.session_state.query_image_url:
            query_image = load_image_from_url(st.session_state.query_image_url)
            if query_image is not None:
                st.session_state.query_image = query_image
                # Perform segmentation
                segmented_image, final_mask, detected_categories = segment_clothing(query_image)
                st.session_state.segmented_image = segmented_image
                st.session_state.detections = detected_categories
                st.image(segmented_image, caption="Segmented Image", use_column_width=True)
                if st.session_state.detections:
                    st.session_state.step = 'select_category'
                else:
                    st.warning("No clothing items detected in the image.")
            else:
                st.error("Failed to load the image. Please try another URL.")
        else:
            st.warning("Please enter an image URL.")

# Step 2: pick one of the detected categories
elif st.session_state.step == 'select_category':
    st.image(st.session_state.segmented_image, caption="Segmented Image with Detected Categories", use_column_width=True)
    st.subheader("Detected Clothing Categories:")

    if st.session_state.detections:
        selected_category = st.selectbox("Select a category to search:", st.session_state.detections)
        if st.button("Search Similar Items"):
            st.session_state.selected_category = selected_category
            st.session_state.step = 'show_results'
    else:
        st.warning("No categories detected.")

# Step 3: embed the segmented garment and show similar catalog items
elif st.session_state.step == 'show_results':
    original_image = st.session_state.query_image.convert("RGB")  # Convert to RGB before displaying
    st.image(original_image, caption="Original Image", use_column_width=True)

    # Get the embedding of the segmented image
    query_embedding = get_image_embedding(st.session_state.segmented_image)
    # Note: st.session_state.selected_category is stored but not yet used to
    # narrow the search; the 'category_upper' collection is category-specific.
    similar_images = find_similar_images(query_embedding, collection)

    st.subheader("Similar Items:")
    for img in similar_images:
        col1, col2 = st.columns(2)
        with col1:
            st.image(img['image_url'], use_column_width=True)
        with col2:
            st.write(f"Name: {img['info']['name']}")
            st.write(f"Brand: {img['info']['brand']}")
            category = img['info'].get('category')
            if category:
                st.write(f"Category: {category}")
            st.write(f"Price: {img['info']['price']}")
            st.write(f"Discount: {img['info']['discount']}%")
            st.write(f"Similarity: {img['similarity']:.2f}")

    if st.button("Start New Search"):
        st.session_state.step = 'input'
        st.session_state.query_image_url = ''
        st.session_state.detections = []
        st.session_state.segmented_image = None
        st.session_state.selected_category = None
        st.rerun()  # Rerun immediately so the input step renders without an extra click