leedoming commited on
Commit
543b03f
1 Parent(s): 301faf2

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +190 -0
app.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import open_clip
3
+ import torch
4
+ import requests
5
+ from PIL import Image
6
+ from io import BytesIO
7
+ import time
8
+ import json
9
+ import numpy as np
10
+ from ultralytics import YOLO
11
+ import cv2
12
+
13
+ # Load CLIP model and tokenizer
14
+ @st.cache_resource
15
+ def load_clip_model():
16
+ model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms('hf-hub:Marqo/marqo-fashionSigLIP')
17
+ tokenizer = open_clip.get_tokenizer('hf-hub:Marqo/marqo-fashionSigLIP')
18
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
+ model.to(device)
20
+ return model, preprocess_val, tokenizer, device
21
+
22
+ clip_model, preprocess_val, tokenizer, device = load_clip_model()
23
+
24
+ # Load YOLOv8 model
25
+ @st.cache_resource
26
+ def load_yolo_model():
27
+ return YOLO("./yolov8m.pt")
28
+
29
+ yolo_model = load_yolo_model()
30
+
31
+ # Load and process data
32
+ @st.cache_data
33
+ def load_data():
34
+ with open('./musinsa-final.json', 'r', encoding='utf-8') as f:
35
+ return json.load(f)
36
+
37
+ data = load_data()
38
+
39
+ # Helper functions
40
+ def load_image_from_url(url, max_retries=3):
41
+ for attempt in range(max_retries):
42
+ try:
43
+ response = requests.get(url, timeout=10)
44
+ response.raise_for_status()
45
+ img = Image.open(BytesIO(response.content)).convert('RGB')
46
+ return img
47
+ except (requests.RequestException, Image.UnidentifiedImageError) as e:
48
+ if attempt < max_retries - 1:
49
+ time.sleep(1)
50
+ else:
51
+ return None
52
+
53
+ def get_image_embedding(image):
54
+ image_tensor = preprocess_val(image).unsqueeze(0).to(device)
55
+ with torch.no_grad():
56
+ image_features = clip_model.encode_image(image_tensor)
57
+ image_features /= image_features.norm(dim=-1, keepdim=True)
58
+ return image_features.cpu().numpy()
59
+
60
+ @st.cache_data
61
+ def process_database():
62
+ database_embeddings = []
63
+ database_info = []
64
+
65
+ for item in data:
66
+ image_url = item['이미지 링크'][0]
67
+ image = load_image_from_url(image_url)
68
+ if image is not None:
69
+ embedding = get_image_embedding(image)
70
+ database_embeddings.append(embedding)
71
+ database_info.append({
72
+ 'id': item['\ufeff상품 ID'],
73
+ 'category': item['카테고리'],
74
+ 'brand': item['브랜드명'],
75
+ 'name': item['제품명'],
76
+ 'price': item['정가'],
77
+ 'discount': item['할인율'],
78
+ 'image_url': image_url
79
+ })
80
+ else:
81
+ st.warning(f"Skipping item {item['상품 ID']} due to image loading failure")
82
+
83
+ if database_embeddings:
84
+ return np.vstack(database_embeddings), database_info
85
+ else:
86
+ st.error("No valid embeddings were generated.")
87
+ return None, None
88
+
89
+ database_embeddings, database_info = process_database()
90
+
91
+ def get_text_embedding(text):
92
+ text_tokens = tokenizer([text]).to(device)
93
+ with torch.no_grad():
94
+ text_features = clip_model.encode_text(text_tokens)
95
+ text_features /= text_features.norm(dim=-1, keepdim=True)
96
+ return text_features.cpu().numpy()
97
+
98
+ def find_similar_images(query_embedding, top_k=5):
99
+ similarities = np.dot(database_embeddings, query_embedding.T).squeeze()
100
+ top_indices = np.argsort(similarities)[::-1][:top_k]
101
+ results = []
102
+ for idx in top_indices:
103
+ results.append({
104
+ 'info': database_info[idx],
105
+ 'similarity': similarities[idx]
106
+ })
107
+ return results
108
+
109
+ def detect_clothing(image):
110
+ results = yolo_model(image)
111
+ detections = results[0].boxes.data.cpu().numpy()
112
+ categories = []
113
+ for detection in detections:
114
+ x1, y1, x2, y2, conf, cls = detection
115
+ category = yolo_model.names[int(cls)]
116
+ if category in ['top', 'bottom', 'hat', 'shoes']: # Add more categories as needed
117
+ categories.append({
118
+ 'category': category,
119
+ 'bbox': [int(x1), int(y1), int(x2), int(y2)],
120
+ 'confidence': conf
121
+ })
122
+ return categories
123
+
124
+ def crop_image(image, bbox):
125
+ return image.crop((bbox[0], bbox[1], bbox[2], bbox[3]))
126
+
127
+ # Streamlit app
128
+ st.title("Advanced Fashion Search App")
129
+
130
+ search_type = st.radio("Search by:", ("Image URL", "Text"))
131
+
132
+ if search_type == "Image URL":
133
+ query_image_url = st.text_input("Enter image URL:")
134
+ if st.button("Detect Clothing"):
135
+ if query_image_url:
136
+ query_image = load_image_from_url(query_image_url)
137
+ if query_image is not None:
138
+ st.image(query_image, caption="Query Image", use_column_width=True)
139
+ detections = detect_clothing(query_image)
140
+
141
+ if detections:
142
+ st.subheader("Detected Clothing Items:")
143
+ selected_category = st.selectbox("Select a category to search:",
144
+ [f"{d['category']} (Confidence: {d['confidence']:.2f})" for d in detections])
145
+
146
+ if st.button("Search Similar Items"):
147
+ selected_detection = next(d for d in detections if f"{d['category']} (Confidence: {d['confidence']:.2f})" == selected_category)
148
+ cropped_image = crop_image(query_image, selected_detection['bbox'])
149
+ query_embedding = get_image_embedding(cropped_image)
150
+ similar_images = find_similar_images(query_embedding)
151
+
152
+ st.subheader("Similar Items:")
153
+ for img in similar_images:
154
+ col1, col2 = st.columns(2)
155
+ with col1:
156
+ st.image(img['info']['image_url'], use_column_width=True)
157
+ with col2:
158
+ st.write(f"Name: {img['info']['name']}")
159
+ st.write(f"Brand: {img['info']['brand']}")
160
+ st.write(f"Category: {img['info']['category']}")
161
+ st.write(f"Price: {img['info']['price']}")
162
+ st.write(f"Discount: {img['info']['discount']}%")
163
+ st.write(f"Similarity: {img['similarity']:.2f}")
164
+ else:
165
+ st.warning("No clothing items detected in the image.")
166
+ else:
167
+ st.error("Failed to load the image. Please try another URL.")
168
+ else:
169
+ st.warning("Please enter an image URL.")
170
+
171
+ else: # Text search
172
+ query_text = st.text_input("Enter search text:")
173
+ if st.button("Search by Text"):
174
+ if query_text:
175
+ text_embedding = get_text_embedding(query_text)
176
+ similar_images = find_similar_images(text_embedding)
177
+ st.subheader("Similar Items:")
178
+ for img in similar_images:
179
+ col1, col2 = st.columns(2)
180
+ with col1:
181
+ st.image(img['info']['image_url'], use_column_width=True)
182
+ with col2:
183
+ st.write(f"Name: {img['info']['name']}")
184
+ st.write(f"Brand: {img['info']['brand']}")
185
+ st.write(f"Category: {img['info']['category']}")
186
+ st.write(f"Price: {img['info']['price']}")
187
+ st.write(f"Discount: {img['info']['discount']}%")
188
+ st.write(f"Similarity: {img['similarity']:.2f}")
189
+ else:
190
+ st.warning("Please enter a search text.")