Browse files
@@ -0,0 +1,211 @@
1 |
import streamlit as st
2 |
import open_clip
3 |
import torch
4 |
import requests
5 |
from PIL import Image
6 |
from io import BytesIO
7 |
import time
8 |
import numpy as np
9 |
from ultralytics import YOLO
10 |
import chromadb
11 |
from transformers import pipeline
12 |
from sklearn.metrics.pairwise import cosine_similarity
13 |
14 |
# Load segmentation model
@st.cache_resource
def load_segmenter():
    """Build the clothes-segmentation pipeline and cache it as a shared resource.

    Streamlit re-executes the script on each interaction; caching keeps the
    pipeline from being constructed again on every rerun.
    """
    return pipeline(model="mattmdjaga/segformer_b2_clothes")


segmenter = load_segmenter()
16 |
17 |
# Load CLIP model and tokenizer
18 |
19 |
@st.cache_resource
def load_clip_model():
    """Load the Marqo FashionSigLIP model together with its transform and tokenizer.

    The result is cached as a Streamlit resource so the weights are not
    reloaded on every script rerun.

    Returns:
        Tuple ``(model, preprocess_val, tokenizer, device)``. ``model`` has
        been moved to ``device`` and switched to evaluation mode.
    """
    model_name = 'hf-hub:Marqo/marqo-fashionSigLIP'
    # The training-time transform is returned as well but is not needed here.
    model, _preprocess_train, preprocess_val = open_clip.create_model_and_transforms(model_name)
    tokenizer = open_clip.get_tokenizer(model_name)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Input tensors are sent to `device` before encoding, so the weights
    # have to live on the same device.
    model.to(device)
    # Inference only: disable training-mode behaviour of the network.
    model.eval()

    return model, preprocess_val, tokenizer, device
25 |
26 |
# Unpack the loader's result into module-level names; `clip_model`,
# `preprocess_val` and `device` are read by get_image_embedding().
clip_model, preprocess_val, tokenizer, device = load_clip_model()
27 |
28 |
# Load YOLOv8 model
29 |
30 |
@st.cache_resource
def load_yolo_model(weights_path="./"):
    """Create the YOLO detector from the given weights location.

    Args:
        weights_path: Location passed straight to ``YOLO``. Defaults to the
            value that was previously hard-coded.

    Returns:
        The constructed ``YOLO`` model, cached as a Streamlit resource.
    """
    # NOTE(review): the default "./" is a directory, not a weights file —
    # the file name looks like it is missing; confirm the intended path.
    return YOLO(weights_path)
32 |
33 |
# Module-level YOLO detector instance.
# NOTE(review): not referenced anywhere else in the visible code — confirm it is still needed.
yolo_model = load_yolo_model()
34 |
35 |
# Open the on-disk ChromaDB store and select the collection to search.
CHROMA_DB_PATH = "./clothesDB_Test"
# Alternative collection name kept for reference: "clothes_items_ver3"
ACTIVE_COLLECTION_NAME = "category_upper"

client = chromadb.PersistentClient(path=CHROMA_DB_PATH)
collection = client.get_collection(name=ACTIVE_COLLECTION_NAME)
39 |
40 |
41 |
# Helper functions
42 |
def load_image_from_url(url, max_retries=3):
    """Download an image and return it as an RGB PIL image.

    Args:
        url: Address of the image to fetch.
        max_retries: Total number of download attempts.

    Returns:
        The decoded image converted to RGB, or ``None`` when every attempt
        failed with a request error or an unidentifiable image.
    """
    for attempt in range(max_retries):
        try:
            response = requests.get(url, timeout=10)
            # Treat HTTP error statuses as failures instead of decoding an error page.
            response.raise_for_status()
            return Image.open(BytesIO(response.content)).convert('RGB')
        except (requests.RequestException, Image.UnidentifiedImageError):
            if attempt < max_retries - 1:
                # Short pause before the next attempt; no wait after the last one.
                time.sleep(1)
    return None
54 |
55 |
def get_image_embedding(image):
    """Encode an image with the CLIP model and return a normalised embedding.

    Args:
        image: Image accepted by ``preprocess_val`` (a PIL image at the call
            site in this file).

    Returns:
        NumPy array with the image features scaled to unit L2 norm along the
        last axis. The leading axis is the batch of one added by
        ``unsqueeze(0)``.
    """
    image_tensor = preprocess_val(image).unsqueeze(0).to(device)
    # Inference only, so skip gradient tracking.
    with torch.no_grad():
        image_features = clip_model.encode_image(image_tensor)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    return image_features.cpu().numpy()
61 |
62 |
def segment_clothing(img, clothes=("Hat", "Upper-clothes", "Skirt", "Pants", "Dress", "Belt", "Left-shoe", "Right-shoe", "Scarf")):
    """Segment clothing in an image and report which categories were found.

    Args:
        img: PIL image to segment.
        clothes: Segment labels that count as clothing; segments with any
            other label are ignored.

    Returns:
        Tuple ``(image, mask, categories)``:
        ``image`` is an RGB PIL image, ``mask`` is an 8-bit PIL mask that is
        255 wherever any selected segment is present and 0 elsewhere, and
        ``categories`` lists the labels of the selected segments in the order
        the segmenter returned them.
    """
    # Segment image
    segments = segmenter(img)

    # Keep only the masks whose label is one of the requested categories.
    mask_list = []
    detected_categories = []
    for s in segments:
        if s['label'] in clothes:
            mask_list.append(s['mask'])
            detected_categories.append(s['label'])

    # Union of all selected masks. Sized from the image itself so the input
    # does not need to be a 3-channel array.
    combined = np.zeros((img.height, img.width), dtype=np.uint8)
    for mask in mask_list:
        combined = np.maximum(combined, np.array(mask))

    # Binarise explicitly: multiplying a uint8 mask that already holds 255 by
    # 255 would wrap around to 1, so map "any non-zero" to 255 instead.
    final_mask = Image.fromarray(np.where(combined > 0, 255, 0).astype(np.uint8))

    # Apply mask to original image
    img_with_alpha = img.convert("RGBA")
    # Reconstructed step: the mask is presumably applied as the alpha channel — confirm.
    img_with_alpha.putalpha(final_mask)

    # NOTE(review): converting back to RGB discards the alpha channel, so the
    # returned image has the same pixels as the input. If the background is
    # meant to be removed, composite onto a solid colour here instead.
    return img_with_alpha.convert("RGB"), final_mask, detected_categories
88 |
89 |
90 |
#def find_similar_images(query_embedding, collection, top_k=5):
91 |
# all_embeddings = collection.get(include=['embeddings'])['embeddings']
92 |
# database_embeddings = np.array(all_embeddings)
93 |
# similarities = np.dot(database_embeddings, query_embedding.T).squeeze()
94 |
# top_indices = np.argsort(similarities)[::-1][:top_k]
95 |
96 |
# all_data = collection.get(include=['metadatas'])['metadatas']
97 |
# top_metadatas = [all_data[idx] for idx in top_indices]
98 |
99 |
# results = []
100 |
# for idx, metadata in enumerate(top_metadatas):
101 |
# results.append({
102 |
# 'info': metadata,
103 |
# 'similarity': similarities[top_indices[idx]]
104 |
# })
105 |
# return results
106 |
107 |
def find_similar_images(query_embedding, collection, top_k=5):
    """Return the stored items most similar to a query embedding.

    Args:
        query_embedding: Embedding of the query image; it is reshaped to a
            single row before comparison.
        collection: Object whose ``get(include=[...])`` returns a mapping with
            ``'embeddings'`` and ``'metadatas'`` entries (a ChromaDB
            collection in this file).
        top_k: Maximum number of results to return.

    Returns:
        List of dicts ordered from most to least similar, each with keys
        ``'info'`` (the item's metadata), ``'similarity'`` (cosine similarity
        to the query) and ``'image_url'`` (first URL of the item's
        comma-separated ``image_url`` metadata field). Empty when the
        collection holds no embeddings or ``top_k`` is not positive.
    """
    # Fetch embeddings and metadata in a single read so index i refers to the
    # same item in both lists.
    stored = collection.get(include=['embeddings', 'metadatas'])
    all_embeddings = stored['embeddings']
    all_data = stored['metadatas']
    if top_k <= 0 or all_embeddings is None or len(all_embeddings) == 0:
        return []

    database_embeddings = np.asarray(all_embeddings)
    query_row = np.asarray(query_embedding).reshape(1, -1)

    # Compute similarities. Column 0 is taken instead of squeeze() so a
    # collection with a single item still yields a 1-D array.
    similarities = cosine_similarity(database_embeddings, query_row)[:, 0]
    top_indices = np.argsort(similarities)[::-1][:top_k]

    results = []
    for idx in top_indices:
        metadata = all_data[idx]
        # The image URLs are stored as one comma-separated string; the first
        # entry is used as the representative image.
        representative_image_url = metadata['image_url'].split(',')[0].strip()
        results.append({
            'info': metadata,
            'similarity': similarities[idx],
            'image_url': representative_image_url,
        })
    return results
133 |
134 |
135 |
# Initialise session state. Every key the step flow reads gets a default so
# no step depends on an earlier step having created it.
_SESSION_DEFAULTS = {
    'step': 'input',
    'query_image_url': '',
    'query_image': None,
    'detections': [],
    'segmented_image': None,
    'selected_category': None,
}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
146 |
147 |
# Streamlit app
st.title("Advanced Fashion Search App")

# Step-by-step flow driven by st.session_state.step:
#   'input' -> 'select_category' -> 'show_results' -> back to 'input'.
# After a step change the script is rerun immediately so the new step is
# rendered without requiring another interaction.
if st.session_state.step == 'input':
    st.session_state.query_image_url = st.text_input("Enter image URL:", st.session_state.query_image_url)
    if st.button("Detect Clothing"):
        if st.session_state.query_image_url:
            query_image = load_image_from_url(st.session_state.query_image_url)
            if query_image is not None:
                st.session_state.query_image = query_image
                # Perform segmentation; the mask itself is not used by the UI.
                segmented_image, _final_mask, detected_categories = segment_clothing(query_image)
                st.session_state.segmented_image = segmented_image
                st.session_state.detections = detected_categories
                st.image(segmented_image, caption="Segmented Image", use_column_width=True)
                if st.session_state.detections:
                    st.session_state.step = 'select_category'
                    st.rerun()
                else:
                    st.warning("No clothing items detected in the image.")
            else:
                st.error("Failed to load the image. Please try another URL.")
        else:
            st.warning("Please enter an image URL.")

elif st.session_state.step == 'select_category':
    st.image(st.session_state.segmented_image, caption="Segmented Image with Detected Categories", use_column_width=True)
    st.subheader("Detected Clothing Categories:")

    if st.session_state.detections:
        selected_category = st.selectbox("Select a category to search:", st.session_state.detections)
        if st.button("Search Similar Items"):
            # NOTE(review): the chosen category is stored but the search below
            # does not read it — confirm whether it should filter the results.
            st.session_state.selected_category = selected_category
            st.session_state.step = 'show_results'
            st.rerun()
    else:
        st.warning("No categories detected.")

elif st.session_state.step == 'show_results':
    original_image = st.session_state.query_image.convert("RGB")  # Convert to RGB before displaying
    st.image(original_image, caption="Original Image", use_column_width=True)

    # Embed the segmented image kept in session state and search the collection.
    query_embedding = get_image_embedding(st.session_state.segmented_image)
    similar_images = find_similar_images(query_embedding, collection)

    st.subheader("Similar Items:")
    for img in similar_images:
        col1, col2 = st.columns(2)
        with col1:
            st.image(img['image_url'], use_column_width=True)
        with col2:
            st.write(f"Name: {img['info']['name']}")
            st.write(f"Brand: {img['info']['brand']}")
            # Category is optional metadata; show it only when present.
            category = img['info'].get('category')
            if category:
                st.write(f"Category: {category}")
            st.write(f"Price: {img['info']['price']}")
            st.write(f"Discount: {img['info']['discount']}%")
            st.write(f"Similarity: {img['similarity']:.2f}")

    if st.button("Start New Search"):
        # Reset every piece of per-search state, not just a subset.
        st.session_state.step = 'input'
        st.session_state.query_image_url = ''
        st.session_state.query_image = None
        st.session_state.detections = []
        st.session_state.segmented_image = None
        st.session_state.selected_category = None
        st.rerun()