leedoming committed on
Commit
bc3fea2
1 Parent(s): 0fec354

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -41
app.py CHANGED
@@ -7,9 +7,9 @@ from io import BytesIO
7
  import time
8
  import json
9
  import numpy as np
10
- from ultralytics import YOLO
11
  import cv2
12
  import chromadb
 
13
 
14
  # Load CLIP model and tokenizer
15
  @st.cache_resource
@@ -22,12 +22,17 @@ def load_clip_model():
22
 
23
  clip_model, preprocess_val, tokenizer, device = load_clip_model()
24
 
25
- # Load YOLOv8 model
26
  @st.cache_resource
27
- def load_yolo_model():
28
- return YOLO("./best.pt")
 
 
29
 
30
- yolo_model = load_yolo_model()
 
 
 
31
 
32
  # Helper functions
33
  def load_image_from_url(url, max_retries=3):
@@ -42,6 +47,7 @@ def load_image_from_url(url, max_retries=3):
42
  time.sleep(1)
43
  else:
44
  return None
 
45
  #Load chromaDB
46
  client = chromadb.PersistentClient(path="./clothesDB")
47
  collection = client.get_collection(name="fashion_items_ver2")
@@ -85,27 +91,32 @@ def find_similar_images(query_embedding, collection, top_k=5):
85
  })
86
  return results
87
 
88
-
89
-
90
  def detect_clothing(image):
91
- results = yolo_model(image)
92
- detections = results[0].boxes.data.cpu().numpy()
 
 
 
 
93
  categories = []
94
- for detection in detections:
95
- x1, y1, x2, y2, conf, cls = detection
96
- category = yolo_model.names[int(cls)]
97
- if category in ['sunglass','hat','jacket','shirt','pants','shorts','skirt','dress','bag','shoe']:
98
  categories.append({
99
  'category': category,
100
- 'bbox': [int(x1), int(y1), int(x2), int(y2)],
101
- 'confidence': conf
102
  })
103
  return categories
104
 
105
  def crop_image(image, bbox):
106
  return image.crop((bbox[0], bbox[1], bbox[2], bbox[3]))
107
 
108
- # 세션 상태 초기화
 
 
 
109
  if 'step' not in st.session_state:
110
  st.session_state.step = 'input'
111
  if 'query_image_url' not in st.session_state:
@@ -115,10 +126,7 @@ if 'detections' not in st.session_state:
115
  if 'selected_category' not in st.session_state:
116
  st.session_state.selected_category = None
117
 
118
- # Streamlit app
119
- st.title("Advanced Fashion Search App")
120
-
121
- # 단계별 처리
122
  if st.session_state.step == 'input':
123
  st.session_state.query_image_url = st.text_input("Enter image URL:", st.session_state.query_image_url)
124
  if st.button("Detect Clothing"):
@@ -136,7 +144,6 @@ if st.session_state.step == 'input':
136
  else:
137
  st.warning("Please enter an image URL.")
138
 
139
- # Update the 'select_category' step
140
  elif st.session_state.step == 'select_category':
141
  st.image(st.session_state.query_image, caption="Query Image", use_column_width=True)
142
  st.subheader("Detected Clothing Items:")
@@ -184,23 +191,22 @@ elif st.session_state.step == 'show_results':
184
  st.session_state.detections = []
185
  st.session_state.selected_category = None
186
 
187
- else: # Text search
188
- query_text = st.text_input("Enter search text:")
189
- if st.button("Search by Text"):
190
- if query_text:
191
- text_embedding = get_text_embedding(query_text)
192
- similar_images = find_similar_images(text_embedding, collection)
193
- st.subheader("Similar Items:")
194
- for img in similar_images:
195
- col1, col2 = st.columns(2)
196
- with col1:
197
- st.image(img['info']['image_url'], use_column_width=True)
198
- with col2:
199
- st.write(f"Name: {img['info']['name']}")
200
- st.write(f"Brand: {img['info']['brand']}")
201
- st.write(f"Category: {img['info']['category']}")
202
- st.write(f"Price: {img['info']['price']}")
203
- st.write(f"Discount: {img['info']['discount']}%")
204
- st.write(f"Similarity: {img['similarity']:.2f}")
205
- else:
206
- st.warning("Please enter a search text.")
 
7
  import time
8
  import json
9
  import numpy as np
 
10
  import cv2
11
  import chromadb
12
+ from transformers import YolosImageProcessor, YolosForObjectDetection
13
 
14
  # Load CLIP model and tokenizer
15
  @st.cache_resource
 
22
 
23
  clip_model, preprocess_val, tokenizer, device = load_clip_model()
24
 
25
+ # Load YOLOS model
26
  @st.cache_resource
27
+ def load_yolos_model():
28
+ processor = YolosImageProcessor.from_pretrained("valentinafeve/yolos-fashionpedia")
29
+ model = YolosForObjectDetection.from_pretrained("valentinafeve/yolos-fashionpedia")
30
+ return processor, model
31
 
32
+ yolos_processor, yolos_model = load_yolos_model()
33
+
34
+ # Define the categories
35
+ CATS = ['shirt, blouse', 'top, t-shirt, sweatshirt', 'sweater', 'cardigan', 'jacket', 'vest', 'pants', 'shorts', 'skirt', 'coat', 'dress', 'jumpsuit', 'cape', 'glasses', 'hat', 'headband, head covering, hair accessory', 'tie', 'glove', 'watch', 'belt', 'leg warmer', 'tights, stockings', 'sock', 'shoe', 'bag, wallet', 'scarf', 'umbrella', 'hood', 'collar', 'lapel', 'epaulette', 'sleeve', 'pocket', 'neckline', 'buckle', 'zipper', 'applique', 'bead', 'bow', 'flower', 'fringe', 'ribbon', 'rivet', 'ruffle', 'sequin', 'tassel']
36
 
37
  # Helper functions
38
  def load_image_from_url(url, max_retries=3):
 
47
  time.sleep(1)
48
  else:
49
  return None
50
+
51
  #Load chromaDB
52
  client = chromadb.PersistentClient(path="./clothesDB")
53
  collection = client.get_collection(name="fashion_items_ver2")
 
91
  })
92
  return results
93
 
 
 
94
  def detect_clothing(image):
95
+ inputs = yolos_processor(images=image, return_tensors="pt")
96
+ outputs = yolos_model(**inputs)
97
+
98
+ target_sizes = torch.tensor([image.size[::-1]])
99
+ results = yolos_processor.post_process_object_detection(outputs, threshold=0.1, target_sizes=target_sizes)[0]
100
+
101
  categories = []
102
+ for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
103
+ box = [int(i) for i in box.tolist()]
104
+ category = yolos_model.config.id2label[label.item()]
105
+ if category in CATS:
106
  categories.append({
107
  'category': category,
108
+ 'bbox': box,
109
+ 'confidence': score.item()
110
  })
111
  return categories
112
 
113
  def crop_image(image, bbox):
114
  return image.crop((bbox[0], bbox[1], bbox[2], bbox[3]))
115
 
116
+ # Streamlit app
117
+ st.title("Advanced Fashion Search App")
118
+
119
+ # Initialize session state
120
  if 'step' not in st.session_state:
121
  st.session_state.step = 'input'
122
  if 'query_image_url' not in st.session_state:
 
126
  if 'selected_category' not in st.session_state:
127
  st.session_state.selected_category = None
128
 
129
+ # Step-by-step processing
 
 
 
130
  if st.session_state.step == 'input':
131
  st.session_state.query_image_url = st.text_input("Enter image URL:", st.session_state.query_image_url)
132
  if st.button("Detect Clothing"):
 
144
  else:
145
  st.warning("Please enter an image URL.")
146
 
 
147
  elif st.session_state.step == 'select_category':
148
  st.image(st.session_state.query_image, caption="Query Image", use_column_width=True)
149
  st.subheader("Detected Clothing Items:")
 
191
  st.session_state.detections = []
192
  st.session_state.selected_category = None
193
 
194
# Text search: all widgets and results are rendered in the sidebar.
# NOTE(review): the original nesting level of this section is not recoverable
# from the flattened diff view -- confirm it belongs at module level.
st.sidebar.title("Text Search")
query_text = st.sidebar.text_input("Enter search text:")
if st.sidebar.button("Search by Text"):
    if not query_text:
        st.sidebar.warning("Please enter a search text.")
    else:
        # Embed the query text and look up the closest items in the collection.
        text_embedding = get_text_embedding(query_text)
        similar_images = find_similar_images(text_embedding, collection)
        st.sidebar.subheader("Similar Items:")
        for match in similar_images:
            info = match['info']
            st.sidebar.image(info['image_url'], use_column_width=True)
            st.sidebar.write(f"Name: {info['name']}")
            st.sidebar.write(f"Brand: {info['brand']}")
            st.sidebar.write(f"Category: {info['category']}")
            st.sidebar.write(f"Price: {info['price']}")
            st.sidebar.write(f"Discount: {info['discount']}%")
            st.sidebar.write(f"Similarity: {match['similarity']:.2f}")
            # Visual separator between consecutive results.
            st.sidebar.write("---")