JoJosmin commited on
Commit
cc63231
1 Parent(s): cdfb401

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +211 -0
app.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import open_clip
3
+ import torch
4
+ import requests
5
+ from PIL import Image
6
+ from io import BytesIO
7
+ import time
8
+ import numpy as np
9
+ from ultralytics import YOLO
10
+ import chromadb
11
+ from transformers import pipeline
12
+ from sklearn.metrics.pairwise import cosine_similarity
13
+
14
+ # Load segmentation model
15
+ segmenter = pipeline(model="mattmdjaga/segformer_b2_clothes")
16
+
17
+ # Load CLIP model and tokenizer
18
+ @st.cache_resource
19
+ def load_clip_model():
20
+ model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms('hf-hub:Marqo/marqo-fashionSigLIP')
21
+ tokenizer = open_clip.get_tokenizer('hf-hub:Marqo/marqo-fashionSigLIP')
22
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
+ model.to(device)
24
+ return model, preprocess_val, tokenizer, device
25
+
26
+ clip_model, preprocess_val, tokenizer, device = load_clip_model()
27
+
28
+ # Load YOLOv8 model
29
+ @st.cache_resource
30
+ def load_yolo_model():
31
+ return YOLO("./best.pt")
32
+
33
+ yolo_model = load_yolo_model()
34
+
35
+ # Load chromaDB
36
+ client = chromadb.PersistentClient(path="./clothesDB_Test")
37
+ #collection = client.get_collection(name="clothes_items_ver3")
38
+ collection = client.get_collection(name="category_upper")
39
+
40
+
41
+ # Helper functions
42
+ def load_image_from_url(url, max_retries=3):
43
+ for attempt in range(max_retries):
44
+ try:
45
+ response = requests.get(url, timeout=10)
46
+ response.raise_for_status()
47
+ img = Image.open(BytesIO(response.content)).convert('RGB')
48
+ return img
49
+ except (requests.RequestException, Image.UnidentifiedImageError) as e:
50
+ if attempt < max_retries - 1:
51
+ time.sleep(1)
52
+ else:
53
+ return None
54
+
55
+ def get_image_embedding(image):
56
+ image_tensor = preprocess_val(image).unsqueeze(0).to(device)
57
+ with torch.no_grad():
58
+ image_features = clip_model.encode_image(image_tensor)
59
+ image_features /= image_features.norm(dim=-1, keepdim=True)
60
+ return image_features.cpu().numpy()
61
+
62
+ def segment_clothing(img, clothes=["Hat", "Upper-clothes", "Skirt", "Pants", "Dress", "Belt", "Left-shoe", "Right-shoe", "Scarf"]):
63
+ # Segment image
64
+ segments = segmenter(img)
65
+
66
+ # Create list of masks
67
+ mask_list = []
68
+ detected_categories = []
69
+ for s in segments:
70
+ if s['label'] in clothes:
71
+ mask_list.append(s['mask'])
72
+ detected_categories.append(s['label']) # Store detected categories
73
+
74
+ # Paste all masks on top of each other
75
+ final_mask = np.zeros_like(np.array(img)[:, :, 0]) # Initialize mask
76
+ for mask in mask_list:
77
+ current_mask = np.array(mask)
78
+ final_mask = np.maximum(final_mask, current_mask) # Use maximum to combine masks
79
+
80
+ # Convert final mask from np array to PIL image
81
+ final_mask = Image.fromarray(final_mask.astype(np.uint8) * 255) # Convert to binary mask
82
+
83
+ # Apply mask to original image
84
+ img_with_alpha = img.convert("RGBA") # Ensure the image has an alpha channel
85
+ img_with_alpha.putalpha(final_mask)
86
+
87
+ return img_with_alpha.convert("RGB"), final_mask, detected_categories # Return detected categories
88
+
89
+
90
+ #def find_similar_images(query_embedding, collection, top_k=5):
91
+ # all_embeddings = collection.get(include=['embeddings'])['embeddings']
92
+ # database_embeddings = np.array(all_embeddings)
93
+ # similarities = np.dot(database_embeddings, query_embedding.T).squeeze()
94
+ # top_indices = np.argsort(similarities)[::-1][:top_k]
95
+ #
96
+ # all_data = collection.get(include=['metadatas'])['metadatas']
97
+ # top_metadatas = [all_data[idx] for idx in top_indices]
98
+
99
+ # results = []
100
+ # for idx, metadata in enumerate(top_metadatas):
101
+ # results.append({
102
+ # 'info': metadata,
103
+ # 'similarity': similarities[top_indices[idx]]
104
+ # })
105
+ # return results
106
+
107
+ def find_similar_images(query_embedding, collection, top_k=5):
108
+ # 모든 임베딩을 가져옴
109
+ all_embeddings = collection.get(include=['embeddings'])['embeddings']
110
+ database_embeddings = np.array(all_embeddings)
111
+
112
+ # 유사도 계산
113
+ similarities = cosine_similarity(database_embeddings, query_embedding.reshape(1, -1)).squeeze()
114
+ top_indices = np.argsort(similarities)[::-1][:top_k]
115
+
116
+ # 메타데이터 가져옴
117
+ all_data = collection.get(include=['metadatas'])['metadatas']
118
+ top_metadatas = [all_data[idx] for idx in top_indices]
119
+
120
+ results = []
121
+ for idx, metadata in enumerate(top_metadatas):
122
+ # 이미지 URLs 필드가 쉼표로 구분된 문자열로 저장된 경우, 이를 리스트로 변환
123
+ image_urls = metadata['image_url'].split(',')
124
+ # 첫 번째 이미지를 대표 이미지로 사용
125
+ representative_image_url = image_urls[0] if image_urls else None
126
+
127
+ results.append({
128
+ 'info': metadata,
129
+ 'similarity': similarities[top_indices[idx]],
130
+ 'image_url': representative_image_url # 첫 번째 이미지 URL 사용
131
+ })
132
+ return results
133
+
134
+
135
+ # 세션 상태 초기화
136
+ if 'step' not in st.session_state:
137
+ st.session_state.step = 'input'
138
+ if 'query_image_url' not in st.session_state:
139
+ st.session_state.query_image_url = ''
140
+ if 'detections' not in st.session_state:
141
+ st.session_state.detections = []
142
+ if 'segmented_image' not in st.session_state: # Add segmented_image to session state
143
+ st.session_state.segmented_image = None
144
+ if 'selected_category' not in st.session_state:
145
+ st.session_state.selected_category = None
146
+
147
+ # Streamlit app
148
+ st.title("Advanced Fashion Search App")
149
+
150
+ # 단계별 처리
151
+ if st.session_state.step == 'input':
152
+ st.session_state.query_image_url = st.text_input("Enter image URL:", st.session_state.query_image_url)
153
+ if st.button("Detect Clothing"):
154
+ if st.session_state.query_image_url:
155
+ query_image = load_image_from_url(st.session_state.query_image_url)
156
+ if query_image is not None:
157
+ st.session_state.query_image = query_image
158
+ # Perform segmentation
159
+ segmented_image, final_mask, detected_categories = segment_clothing(query_image)
160
+ st.session_state.segmented_image = segmented_image # Store segmented image in session state
161
+ st.session_state.detections = detected_categories # Store detected categories
162
+ st.image(segmented_image, caption="Segmented Image", use_column_width=True)
163
+ if st.session_state.detections:
164
+ st.session_state.step = 'select_category'
165
+ else:
166
+ st.warning("No clothing items detected in the image.")
167
+ else:
168
+ st.error("Failed to load the image. Please try another URL.")
169
+ else:
170
+ st.warning("Please enter an image URL.")
171
+
172
+ elif st.session_state.step == 'select_category':
173
+ st.image(st.session_state.segmented_image, caption="Segmented Image with Detected Categories", use_column_width=True)
174
+ st.subheader("Detected Clothing Categories:")
175
+
176
+ if st.session_state.detections:
177
+ selected_category = st.selectbox("Select a category to search:", st.session_state.detections)
178
+ if st.button("Search Similar Items"):
179
+ st.session_state.selected_category = selected_category
180
+ st.session_state.step = 'show_results'
181
+ else:
182
+ st.warning("No categories detected.")
183
+
184
+ elif st.session_state.step == 'show_results':
185
+ original_image = st.session_state.query_image.convert("RGB") # Convert to RGB before displaying
186
+ st.image(original_image, caption="Original Image", use_column_width=True)
187
+
188
+ # Get the embedding of the segmented image
189
+ query_embedding = get_image_embedding(st.session_state.segmented_image) # Use the segmented image from session state
190
+ similar_images = find_similar_images(query_embedding, collection)
191
+
192
+ st.subheader("Similar Items:")
193
+ for img in similar_images:
194
+ col1, col2 = st.columns(2)
195
+ with col1:
196
+ st.image(img['image_url'], use_column_width=True)
197
+ with col2:
198
+ st.write(f"Name: {img['info']['name']}")
199
+ st.write(f"Brand: {img['info']['brand']}")
200
+ category = img['info'].get('category')
201
+ if category:
202
+ st.write(f"Category: {category}")
203
+ st.write(f"Price: {img['info']['price']}")
204
+ st.write(f"Discount: {img['info']['discount']}%")
205
+ st.write(f"Similarity: {img['similarity']:.2f}")
206
+
207
+ if st.button("Start New Search"):
208
+ st.session_state.step = 'input'
209
+ st.session_state.query_image_url = ''
210
+ st.session_state.detections = []
211
+ st.session_state.segmented_image = None # Reset segmented_image