Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -7,9 +7,9 @@ from io import BytesIO
|
|
7 |
import time
|
8 |
import json
|
9 |
import numpy as np
|
10 |
-
from ultralytics import YOLO
|
11 |
import cv2
|
12 |
import chromadb
|
|
|
13 |
|
14 |
# Load CLIP model and tokenizer
|
15 |
@st.cache_resource
|
@@ -22,12 +22,17 @@ def load_clip_model():
|
|
22 |
|
23 |
clip_model, preprocess_val, tokenizer, device = load_clip_model()
|
24 |
|
25 |
-
# Load
|
26 |
@st.cache_resource
|
27 |
-
def
|
28 |
-
|
|
|
|
|
29 |
|
30 |
-
|
|
|
|
|
|
|
31 |
|
32 |
# Helper functions
|
33 |
def load_image_from_url(url, max_retries=3):
|
@@ -42,6 +47,7 @@ def load_image_from_url(url, max_retries=3):
|
|
42 |
time.sleep(1)
|
43 |
else:
|
44 |
return None
|
|
|
45 |
#Load chromaDB
|
46 |
client = chromadb.PersistentClient(path="./clothesDB")
|
47 |
collection = client.get_collection(name="fashion_items_ver2")
|
@@ -85,27 +91,32 @@ def find_similar_images(query_embedding, collection, top_k=5):
|
|
85 |
})
|
86 |
return results
|
87 |
|
88 |
-
|
89 |
-
|
90 |
def detect_clothing(image):
|
91 |
-
|
92 |
-
|
|
|
|
|
|
|
|
|
93 |
categories = []
|
94 |
-
for
|
95 |
-
|
96 |
-
category =
|
97 |
-
if category in
|
98 |
categories.append({
|
99 |
'category': category,
|
100 |
-
'bbox':
|
101 |
-
'confidence':
|
102 |
})
|
103 |
return categories
|
104 |
|
105 |
def crop_image(image, bbox):
|
106 |
return image.crop((bbox[0], bbox[1], bbox[2], bbox[3]))
|
107 |
|
108 |
-
#
|
|
|
|
|
|
|
109 |
if 'step' not in st.session_state:
|
110 |
st.session_state.step = 'input'
|
111 |
if 'query_image_url' not in st.session_state:
|
@@ -115,10 +126,7 @@ if 'detections' not in st.session_state:
|
|
115 |
if 'selected_category' not in st.session_state:
|
116 |
st.session_state.selected_category = None
|
117 |
|
118 |
-
#
|
119 |
-
st.title("Advanced Fashion Search App")
|
120 |
-
|
121 |
-
# 단계별 처리
|
122 |
if st.session_state.step == 'input':
|
123 |
st.session_state.query_image_url = st.text_input("Enter image URL:", st.session_state.query_image_url)
|
124 |
if st.button("Detect Clothing"):
|
@@ -136,7 +144,6 @@ if st.session_state.step == 'input':
|
|
136 |
else:
|
137 |
st.warning("Please enter an image URL.")
|
138 |
|
139 |
-
# Update the 'select_category' step
|
140 |
elif st.session_state.step == 'select_category':
|
141 |
st.image(st.session_state.query_image, caption="Query Image", use_column_width=True)
|
142 |
st.subheader("Detected Clothing Items:")
|
@@ -184,23 +191,22 @@ elif st.session_state.step == 'show_results':
|
|
184 |
st.session_state.detections = []
|
185 |
st.session_state.selected_category = None
|
186 |
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
st.warning("Please enter a search text.")
|
|
|
7 |
import time
|
8 |
import json
|
9 |
import numpy as np
|
|
|
10 |
import cv2
|
11 |
import chromadb
|
12 |
+
from transformers import YolosImageProcessor, YolosForObjectDetection
|
13 |
|
14 |
# Load CLIP model and tokenizer
|
15 |
@st.cache_resource
|
|
|
22 |
|
23 |
clip_model, preprocess_val, tokenizer, device = load_clip_model()
|
24 |
|
25 |
+
# Load YOLOS model
|
26 |
@st.cache_resource
|
27 |
+
def load_yolos_model():
|
28 |
+
processor = YolosImageProcessor.from_pretrained("valentinafeve/yolos-fashionpedia")
|
29 |
+
model = YolosForObjectDetection.from_pretrained("valentinafeve/yolos-fashionpedia")
|
30 |
+
return processor, model
|
31 |
|
32 |
+
yolos_processor, yolos_model = load_yolos_model()
|
33 |
+
|
34 |
+
# Define the categories
|
35 |
+
CATS = ['shirt, blouse', 'top, t-shirt, sweatshirt', 'sweater', 'cardigan', 'jacket', 'vest', 'pants', 'shorts', 'skirt', 'coat', 'dress', 'jumpsuit', 'cape', 'glasses', 'hat', 'headband, head covering, hair accessory', 'tie', 'glove', 'watch', 'belt', 'leg warmer', 'tights, stockings', 'sock', 'shoe', 'bag, wallet', 'scarf', 'umbrella', 'hood', 'collar', 'lapel', 'epaulette', 'sleeve', 'pocket', 'neckline', 'buckle', 'zipper', 'applique', 'bead', 'bow', 'flower', 'fringe', 'ribbon', 'rivet', 'ruffle', 'sequin', 'tassel']
|
36 |
|
37 |
# Helper functions
|
38 |
def load_image_from_url(url, max_retries=3):
|
|
|
47 |
time.sleep(1)
|
48 |
else:
|
49 |
return None
|
50 |
+
|
51 |
#Load chromaDB
|
52 |
client = chromadb.PersistentClient(path="./clothesDB")
|
53 |
collection = client.get_collection(name="fashion_items_ver2")
|
|
|
91 |
})
|
92 |
return results
|
93 |
|
|
|
|
|
94 |
def detect_clothing(image):
|
95 |
+
inputs = yolos_processor(images=image, return_tensors="pt")
|
96 |
+
outputs = yolos_model(**inputs)
|
97 |
+
|
98 |
+
target_sizes = torch.tensor([image.size[::-1]])
|
99 |
+
results = yolos_processor.post_process_object_detection(outputs, threshold=0.1, target_sizes=target_sizes)[0]
|
100 |
+
|
101 |
categories = []
|
102 |
+
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
|
103 |
+
box = [int(i) for i in box.tolist()]
|
104 |
+
category = yolos_model.config.id2label[label.item()]
|
105 |
+
if category in CATS:
|
106 |
categories.append({
|
107 |
'category': category,
|
108 |
+
'bbox': box,
|
109 |
+
'confidence': score.item()
|
110 |
})
|
111 |
return categories
|
112 |
|
113 |
def crop_image(image, bbox):
|
114 |
return image.crop((bbox[0], bbox[1], bbox[2], bbox[3]))
|
115 |
|
116 |
+
# Streamlit app
|
117 |
+
st.title("Advanced Fashion Search App")
|
118 |
+
|
119 |
+
# Initialize session state
|
120 |
if 'step' not in st.session_state:
|
121 |
st.session_state.step = 'input'
|
122 |
if 'query_image_url' not in st.session_state:
|
|
|
126 |
if 'selected_category' not in st.session_state:
|
127 |
st.session_state.selected_category = None
|
128 |
|
129 |
+
# Step-by-step processing
|
|
|
|
|
|
|
130 |
if st.session_state.step == 'input':
|
131 |
st.session_state.query_image_url = st.text_input("Enter image URL:", st.session_state.query_image_url)
|
132 |
if st.button("Detect Clothing"):
|
|
|
144 |
else:
|
145 |
st.warning("Please enter an image URL.")
|
146 |
|
|
|
147 |
elif st.session_state.step == 'select_category':
|
148 |
st.image(st.session_state.query_image, caption="Query Image", use_column_width=True)
|
149 |
st.subheader("Detected Clothing Items:")
|
|
|
191 |
st.session_state.detections = []
|
192 |
st.session_state.selected_category = None
|
193 |
|
194 |
+
# Text search
|
195 |
+
st.sidebar.title("Text Search")
|
196 |
+
query_text = st.sidebar.text_input("Enter search text:")
|
197 |
+
if st.sidebar.button("Search by Text"):
|
198 |
+
if query_text:
|
199 |
+
text_embedding = get_text_embedding(query_text)
|
200 |
+
similar_images = find_similar_images(text_embedding, collection)
|
201 |
+
st.sidebar.subheader("Similar Items:")
|
202 |
+
for img in similar_images:
|
203 |
+
st.sidebar.image(img['info']['image_url'], use_column_width=True)
|
204 |
+
st.sidebar.write(f"Name: {img['info']['name']}")
|
205 |
+
st.sidebar.write(f"Brand: {img['info']['brand']}")
|
206 |
+
st.sidebar.write(f"Category: {img['info']['category']}")
|
207 |
+
st.sidebar.write(f"Price: {img['info']['price']}")
|
208 |
+
st.sidebar.write(f"Discount: {img['info']['discount']}%")
|
209 |
+
st.sidebar.write(f"Similarity: {img['similarity']:.2f}")
|
210 |
+
st.sidebar.write("---")
|
211 |
+
else:
|
212 |
+
st.sidebar.warning("Please enter a search text.")
|
|