Update app.py
app.py CHANGED
@@ -11,15 +11,12 @@ import cv2
 import chromadb
 from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
 import torch.nn as nn
-import
-
-# Suppress specific warnings
-warnings.filterwarnings("ignore", category=UserWarning, module="transformers.utils.deprecation")
+import matplotlib.pyplot as plt
 
 # Load CLIP model and tokenizer
 @st.cache_resource
 def load_clip_model():
-    model,
+    model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms('hf-hub:Marqo/marqo-fashionSigLIP')
     tokenizer = open_clip.get_tokenizer('hf-hub:Marqo/marqo-fashionSigLIP')
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     model.to(device)
@@ -30,12 +27,21 @@ clip_model, preprocess_val, tokenizer, device = load_clip_model()
 # Load SegFormer model
 @st.cache_resource
 def load_segformer_model():
-    processor = SegformerImageProcessor.from_pretrained("
-    model = AutoModelForSemanticSegmentation.from_pretrained("
+    processor = SegformerImageProcessor.from_pretrained("sayeed99/segformer-b3-fashion")
+    model = AutoModelForSemanticSegmentation.from_pretrained("sayeed99/segformer-b3-fashion")
     return model, processor
 
 segformer_model, segformer_processor = load_segformer_model()
 
+# Load ChromaDB
+@st.cache_resource
+def load_chromadb():
+    client = chromadb.PersistentClient(path="./clothesDB")
+    collection = client.get_collection(name="clothes_items_ver3")
+    return collection
+
+collection = load_chromadb()
+
 # Helper functions
 def load_image_from_url(url, max_retries=3):
     for attempt in range(max_retries):
@@ -50,15 +56,6 @@ def load_image_from_url(url, max_retries=3):
         else:
             return None
 
-# Load ChromaDB
-@st.cache_resource
-def load_chromadb():
-    client = chromadb.PersistentClient(path="./clothesDB")
-    collection = client.get_collection(name="clothes_items_ver3")
-    return collection
-
-collection = load_chromadb()
-
 def get_image_embedding(image):
     image_tensor = preprocess_val(image).unsqueeze(0).to(device)
     with torch.no_grad():
@@ -73,35 +70,23 @@ def get_text_embedding(text):
     text_features /= text_features.norm(dim=-1, keepdim=True)
     return text_features.cpu().numpy()
 
-def get_all_embeddings_from_collection(collection):
-    all_embeddings = collection.get(include=['embeddings'])['embeddings']
-    return np.array(all_embeddings)
-
-def get_metadata_from_ids(collection, ids):
-    results = collection.get(ids=ids)
-    return results['metadatas']
-
 def find_similar_images(query_embedding, collection, top_k=5):
-    database_embeddings =
+    database_embeddings = np.array(collection.get(include=['embeddings'])['embeddings'])
     similarities = np.dot(database_embeddings, query_embedding.T).squeeze()
     top_indices = np.argsort(similarities)[::-1][:top_k]
 
     all_data = collection.get(include=['metadatas'])['metadatas']
 
-
-
-
-
-        results.append({
-            'info': metadata,
-            'similarity': similarities[top_indices[idx]]
-        })
+    results = [
+        {'info': all_data[idx], 'similarity': similarities[idx]}
+        for idx in top_indices
+    ]
     return results
 
 def segment_clothing(image):
     inputs = segformer_processor(images=image, return_tensors="pt")
     outputs = segformer_model(**inputs)
-    logits = outputs.logits
+    logits = outputs.logits
 
     upsampled_logits = nn.functional.interpolate(
         logits,
@@ -110,7 +95,7 @@ def segment_clothing(image):
         align_corners=False,
     )
 
-    pred_seg = upsampled_logits.argmax(dim=1)[0].numpy()
+    pred_seg = upsampled_logits.argmax(dim=1)[0].cpu().numpy()
 
     categories = segformer_model.config.id2label
     segmented_items = []
@@ -127,7 +112,23 @@ def segment_clothing(image):
             'mask': mask
         })
 
-    return segmented_items, pred_seg
+    return segmented_items, pred_seg, categories
+
+def visualize_segmentation(pred_seg, categories):
+    plt.figure(figsize=(10, 10))
+    plt.imshow(pred_seg, cmap='jet')
+    plt.colorbar(label='Category ID')
+    plt.title("Segmentation Result")
+    plt.axis('off')
+
+    # Add legend
+    unique_classes = np.unique(pred_seg)
+    legend_elements = [plt.Rectangle((0,0),1,1, color=plt.cm.jet(category_id/len(categories)),
+                                     label=f"{category_id}: {categories[category_id]}")
+                       for category_id in unique_classes if category_id in categories]
+    plt.legend(handles=legend_elements, loc='center left', bbox_to_anchor=(1, 0.5))
+
+    return plt
 
 def crop_image(image, bbox):
     return image.crop((bbox[0], bbox[1], bbox[2], bbox[3]))
@@ -153,7 +154,7 @@ if st.session_state.step == 'input':
         query_image = load_image_from_url(st.session_state.query_image_url)
         if query_image is not None:
            st.session_state.query_image = query_image
-            st.session_state.segmentations, st.session_state.pred_seg = segment_clothing(query_image)
+            st.session_state.segmentations, st.session_state.pred_seg, st.session_state.categories = segment_clothing(query_image)
            if st.session_state.segmentations:
                st.session_state.step = 'select_category'
            else:
@@ -168,8 +169,10 @@ elif st.session_state.step == 'select_category':
     with col1:
         st.image(st.session_state.query_image, caption="Original Image", use_column_width=True)
     with col2:
-
-
+        seg_fig = visualize_segmentation(st.session_state.pred_seg, st.session_state.categories)
+        st.pyplot(seg_fig)
+        plt.close(seg_fig)  # Prevent memory leaks
+
     st.subheader("Segmented Clothing Items:")
 
     for segmentation in st.session_state.segmentations:
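
Note on the relocated ChromaDB block: load_chromadb still expects a pre-built persistent store at ./clothesDB containing a collection named clothes_items_ver3; get_collection only looks it up and fails if it is absent. A rough, hypothetical seeding sketch (the ids, embedding size, and metadata keys below are illustrative assumptions, not the app's actual schema):

import chromadb

# Hypothetical seeding script for the store app.py expects; field names are illustrative only.
client = chromadb.PersistentClient(path="./clothesDB")
collection = client.get_or_create_collection(name="clothes_items_ver3")

collection.add(
    ids=["item_0001"],
    embeddings=[[0.0] * 768],  # placeholder vector; a real entry would hold a FashionSigLIP image embedding
    metadatas=[{"image_url": "https://example.com/item_0001.jpg", "category": "top"}],
)
print(collection.count())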
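For reference on the rewritten find_similar_images: get_image_embedding and get_text_embedding L2-normalize their outputs, so the dot product against the stored embeddings is a cosine similarity. A minimal, self-contained sketch of that ranking step, with made-up NumPy arrays standing in for what the collection would return (no ChromaDB or Streamlit needed):

import numpy as np

# Stand-ins for collection.get(include=['embeddings'])['embeddings'] and
# collection.get(include=['metadatas'])['metadatas']; shapes and values are made up.
database_embeddings = np.random.rand(8, 768).astype(np.float32)
database_embeddings /= np.linalg.norm(database_embeddings, axis=1, keepdims=True)  # unit-length rows
all_data = [{"item": i} for i in range(8)]

query_embedding = np.random.rand(1, 768).astype(np.float32)
query_embedding /= np.linalg.norm(query_embedding)  # unit-length query

# Dot products of unit vectors are cosine similarities; keep the top_k highest.
similarities = np.dot(database_embeddings, query_embedding.T).squeeze()
top_indices = np.argsort(similarities)[::-1][:5]
results = [{"info": all_data[idx], "similarity": similarities[idx]} for idx in top_indices]
print(results)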
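On the visualization hookup: visualize_segmentation returns the pyplot module itself, so seg_fig in the Streamlit block is not a Figure object. If st.pyplot and plt.close are meant to receive an explicit Figure, a variant along these lines keeps the close() call unambiguous (visualize_segmentation_fig is a hypothetical name, not part of app.py):

import numpy as np
import matplotlib.pyplot as plt

def visualize_segmentation_fig(pred_seg, categories):
    # Same plot as visualize_segmentation, but built on an explicit Figure/Axes pair.
    fig, ax = plt.subplots(figsize=(10, 10))
    im = ax.imshow(pred_seg, cmap='jet')
    fig.colorbar(im, ax=ax, label='Category ID')
    ax.set_title("Segmentation Result")
    ax.axis('off')

    # Legend: one patch per class id actually present in the prediction.
    unique_classes = np.unique(pred_seg)
    legend_elements = [plt.Rectangle((0, 0), 1, 1, color=plt.cm.jet(category_id / len(categories)),
                                     label=f"{category_id}: {categories[category_id]}")
                       for category_id in unique_classes if category_id in categories]
    ax.legend(handles=legend_elements, loc='center left', bbox_to_anchor=(1, 0.5))
    return fig

# In the Streamlit block this would become:
#     fig = visualize_segmentation_fig(st.session_state.pred_seg, st.session_state.categories)
#     st.pyplot(fig)
#     plt.close(fig)  # closes that specific figure and frees its memory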