# itda-segment / app.py
# Source: Hugging Face Space by leedoming ("Upload 14 files", commit 2dba380, verified; 18.3 kB).
import streamlit as st
import open_clip
import torch
from PIL import Image
import numpy as np
from transformers import pipeline
import chromadb
import logging
import io
import requests
from concurrent.futures import ThreadPoolExecutor
# Logging setup: INFO-level root configuration plus a logger named after this module.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
# Initialize session state: seed each key only if it is missing, so values
# stored by callbacks/widgets survive Streamlit reruns.
for _state_key, _state_default in (
    ('image', None),
    ('detected_items', None),
    ('selected_item_index', None),
    ('upload_state', 'initial'),
    ('search_clicked', False),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
del _state_key, _state_default
# Load models (cached with st.cache_resource so reruns reuse the loaded objects)
@st.cache_resource
def load_models():
    """Load the CLIP embedding model and the clothes segmentation pipeline.

    Returns:
        tuple: ``(model, preprocess_val, segmenter, device)``:
            - ``model``: open_clip ``Marqo/marqo-fashionSigLIP`` model, moved to
              ``device`` and switched to eval mode.
            - ``preprocess_val``: the validation-time image transform returned by
              ``open_clip.create_model_and_transforms``.
            - ``segmenter``: transformers ``pipeline`` for
              ``mattmdjaga/segformer_b2_clothes``.
            - ``device``: CUDA if available, otherwise CPU.

    Raises:
        Exception: any loading error is logged (with traceback) and re-raised.
    """
    try:
        # CLIP model (the training transform is not needed for inference)
        model, _, preprocess_val = open_clip.create_model_and_transforms('hf-hub:Marqo/marqo-fashionSigLIP')
        # Segmentation model
        segmenter = pipeline(model="mattmdjaga/segformer_b2_clothes")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        # open_clip hands back the model in training mode; put it in eval mode
        # so dropout / stochastic layers are off and embeddings are deterministic.
        model.eval()
        return model, preprocess_val, segmenter, device
    except Exception as e:
        logger.exception(f"Error loading models: {e}")
        raise
# ๋ชจ๋ธ ๋กœ๋“œ
clip_model, preprocess_val, segmenter, device = load_models()
# ChromaDB ์„ค์ •
client = chromadb.PersistentClient(path="./clothesDB_11GmarketMusinsa")
collection = client.get_collection(name="clothes")
def process_segmentation(image):
"""Segmentation processing"""
try:
# pipeline ์ถœ๋ ฅ ๊ฒฐ๊ณผ ์ง์ ‘ ์ฒ˜๋ฆฌ
output = segmenter(image)
if not output:
logger.warning("No segments found in image")
return None
# ๊ฐ ์„ธ๊ทธ๋จผํŠธ์˜ ๋งˆ์Šคํฌ ํฌ๊ธฐ ๊ณ„์‚ฐ
segment_sizes = [np.sum(seg['mask']) for seg in output]
if not segment_sizes:
return None
# ๊ฐ€์žฅ ํฐ ์„ธ๊ทธ๋จผํŠธ ์„ ํƒ
largest_idx = np.argmax(segment_sizes)
mask = output[largest_idx]['mask']
# ๋งˆ์Šคํฌ๊ฐ€ numpy array๊ฐ€ ์•„๋‹Œ ๊ฒฝ์šฐ ๋ณ€ํ™˜
if not isinstance(mask, np.ndarray):
mask = np.array(mask)
# ๋งˆ์Šคํฌ๊ฐ€ 2D๊ฐ€ ์•„๋‹Œ ๊ฒฝ์šฐ ์ฒซ ๋ฒˆ์งธ ์ฑ„๋„ ์‚ฌ์šฉ
if len(mask.shape) > 2:
mask = mask[:, :, 0]
# bool ๋งˆ์Šคํฌ๋ฅผ float๋กœ ๋ณ€ํ™˜
mask = mask.astype(float)
logger.info(f"Successfully created mask with shape {mask.shape}")
return mask
except Exception as e:
logger.error(f"Segmentation error: {str(e)}")
import traceback
logger.error(traceback.format_exc())
return None
def download_and_process_image(image_url, metadata_id):
    """Fetch an image over HTTP, segment it, and return its masked CLIP embedding.

    Args:
        image_url: URL of the product image.
        metadata_id: identifier used only in log messages.

    Returns:
        The feature vector produced by ``extract_features``, or None when the
        response status is not 200, no mask is produced, or any step raises
        (the error and traceback are logged).
    """
    try:
        # Bound the connect/read waits so a slow host cannot hang the caller.
        resp = requests.get(image_url, timeout=10)
        if resp.status_code != 200:
            logger.error(f"Failed to download image {metadata_id}: HTTP {resp.status_code}")
            return None

        img = Image.open(io.BytesIO(resp.content)).convert('RGB')
        logger.info(f"Successfully downloaded image {metadata_id}")

        seg_mask = process_segmentation(img)
        if seg_mask is None:
            logger.warning(f"No valid mask found for image {metadata_id}")
            return None

        embedding = extract_features(img, seg_mask)
        logger.info(f"Successfully extracted features for image {metadata_id}")
        return embedding
    except Exception as exc:
        import traceback
        logger.error(f"Error processing image {metadata_id}: {str(exc)}")
        logger.error(traceback.format_exc())
        return None
def _recreate_segmented_collection():
    """Drop the ``clothes_segmented`` collection if present and create it anew."""
    try:
        client.delete_collection("clothes_segmented")
        logger.info("Deleted existing segmented collection")
    except Exception:
        # Deleting a collection that does not exist (e.g. first run) raises;
        # that case is expected, so it is only logged.
        logger.info("No existing segmented collection to delete")
    new_collection = client.create_collection(
        name="clothes_segmented",
        metadata={"description": "Clothes collection with segmentation-based features"},
    )
    logger.info("Created new segmented collection")
    return new_collection


def _fetch_source_items():
    """Return ``(ids, metadatas)`` for every item of the source ``collection``.

    Returns two empty lists (after logging the error) if the read fails.
    """
    try:
        all_items = collection.get(include=['metadatas'])
    except Exception as e:
        logger.error(f"Error getting items from collection: {str(e)}")
        return [], []
    ids = list(all_items.get('ids') or [])
    metadatas = list(all_items.get('metadatas') or [])
    logger.info(f"Found {len(metadatas)} items in database")
    return ids, metadatas


def update_db_with_segmentation(max_workers=4):
    """Rebuild ``clothes_segmented`` with segmentation-based embeddings.

    The ``clothes_segmented`` collection is dropped and recreated. Every item
    of the source ``collection`` whose metadata has a non-empty ``image_url``
    is downloaded, segmented and embedded via ``download_and_process_image``
    on a thread pool. Each successful embedding is stored with the original
    metadata. Progress is shown with a Streamlit progress bar and status text.

    Args:
        max_workers: number of worker threads used for download + embedding.

    Returns:
        True if at least one item was stored, otherwise False. This includes
        unexpected errors, which are logged with a traceback.
    """
    try:
        logger.info("Starting database update with segmentation")
        new_collection = _recreate_segmented_collection()
        source_ids, metadatas = _fetch_source_items()

        # Only items with an image URL can be re-embedded; metadata entries
        # may be None for items stored without metadata.
        valid_items = [
            (source_id, metadata)
            for source_id, metadata in zip(source_ids, metadatas)
            if metadata and metadata.get('image_url')
        ]
        total = len(valid_items)

        # Progress display
        progress_bar = st.progress(0)
        status_text = st.empty()
        successful_updates = 0
        failed_updates = 0

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = []
            for source_id, metadata in valid_items:
                # Prefer the metadata's own id and fall back to the source
                # collection's id. That id is stable across runs, unlike
                # hash() of a str. Chroma ids must be strings.
                item_id = metadata.get('id')
                item_id = str(source_id if item_id is None else item_id)
                future = executor.submit(download_and_process_image, metadata['image_url'], item_id)
                futures.append((item_id, metadata, future))

            # Results are consumed in submission order on this thread, so the
            # Streamlit widgets and the new collection are only touched here.
            for done, (item_id, metadata, future) in enumerate(futures, start=1):
                try:
                    new_features = future.result()
                except Exception as e:
                    logger.error(f"Error processing item: {str(e)}")
                    new_features = None

                if new_features is None:
                    failed_updates += 1
                else:
                    try:
                        new_collection.add(
                            embeddings=[new_features.tolist()],
                            metadatas=[metadata],
                            ids=[item_id],
                        )
                        successful_updates += 1
                        logger.info(f"Successfully added item {item_id}")
                    except Exception as e:
                        logger.error(f"Error adding item to new collection: {str(e)}")
                        failed_updates += 1

                # Update progress for every item, failures included.
                progress_bar.progress(done / total)
                status_text.text(
                    f"Processing: {done}/{total} items. "
                    f"Success: {successful_updates}, Failed: {failed_updates}"
                )

        # Final summary
        status_text.text(f"Update completed. Successfully processed: {successful_updates}, Failed: {failed_updates}")
        logger.info(f"Database update completed. Successful: {successful_updates}, Failed: {failed_updates}")

        if successful_updates > 0:
            return True
        logger.error("No items were successfully processed")
        return False
    except Exception as e:
        logger.error(f"Database update error: {str(e)}")
        import traceback
        logger.error(traceback.format_exc())
        return False
def extract_features(image, mask=None):
"""Extract CLIP features with segmentation mask"""
try:
if mask is not None:
img_array = np.array(image)
mask = np.expand_dims(mask, axis=2)
masked_img = img_array * mask
masked_img[mask[:,:,0] == 0] = 255 # ๋ฐฐ๊ฒฝ์„ ํฐ์ƒ‰์œผ๋กœ
image = Image.fromarray(masked_img.astype(np.uint8))
image_tensor = preprocess_val(image).unsqueeze(0).to(device)
with torch.no_grad():
features = clip_model.encode_image(image_tensor)
features /= features.norm(dim=-1, keepdim=True)
return features.cpu().numpy().flatten()
except Exception as e:
logger.error(f"Feature extraction error: {e}")
raise
def search_similar_items(features, top_k=10):
"""Search similar items using segmentation-based features"""
try:
# ์„ธ๊ทธ๋ฉ˜ํ…Œ์ด์…˜์ด ์ ์šฉ๋œ collection์ด ์žˆ๋Š”์ง€ ํ™•์ธ
try:
search_collection = client.get_collection("clothes_segmented")
logger.info("Using segmented collection for search")
except:
# ์—†์œผ๋ฉด ๊ธฐ์กด collection ์‚ฌ์šฉ
search_collection = collection
logger.info("Using original collection for search")
results = search_collection.query(
query_embeddings=[features.tolist()],
n_results=top_k,
include=['metadatas', 'distances']
)
if not results or not results['metadatas'] or not results['distances']:
logger.warning("No results returned from ChromaDB")
return []
similar_items = []
for metadata, distance in zip(results['metadatas'][0], results['distances'][0]):
try:
similarity_score = 1 / (1 + float(distance))
item_data = metadata.copy()
item_data['similarity_score'] = similarity_score
similar_items.append(item_data)
except Exception as e:
logger.error(f"Error processing search result: {str(e)}")
continue
similar_items.sort(key=lambda x: x['similarity_score'], reverse=True)
return similar_items
except Exception as e:
logger.error(f"Search error: {str(e)}")
return []
def show_similar_items(similar_items):
"""Display similar items in a structured format with similarity scores"""
if not similar_items:
st.warning("No similar items found.")
return
st.subheader("Similar Items:")
# ๊ฒฐ๊ณผ๋ฅผ 2์—ด๋กœ ํ‘œ์‹œ
items_per_row = 2
for i in range(0, len(similar_items), items_per_row):
cols = st.columns(items_per_row)
for j, col in enumerate(cols):
if i + j < len(similar_items):
item = similar_items[i + j]
with col:
try:
if 'image_url' in item:
st.image(item['image_url'], use_column_width=True)
# ์œ ์‚ฌ๋„ ์ ์ˆ˜๋ฅผ ํผ์„ผํŠธ๋กœ ํ‘œ์‹œ
similarity_percent = item['similarity_score'] * 100
st.markdown(f"**Similarity: {similarity_percent:.1f}%**")
st.write(f"Brand: {item.get('brand', 'Unknown')}")
name = item.get('name', 'Unknown')
if len(name) > 50: # ๊ธด ์ด๋ฆ„์€ ์ค„์ž„
name = name[:47] + "..."
st.write(f"Name: {name}")
# ๊ฐ€๊ฒฉ ์ •๋ณด ํ‘œ์‹œ
price = item.get('price', 0)
if isinstance(price, (int, float)):
st.write(f"Price: {price:,}์›")
else:
st.write(f"Price: {price}")
# ํ• ์ธ ์ •๋ณด๊ฐ€ ์žˆ๋Š” ๊ฒฝ์šฐ
if 'discount' in item and item['discount']:
st.write(f"Discount: {item['discount']}%")
if 'original_price' in item:
st.write(f"Original: {item['original_price']:,}์›")
st.divider() # ๊ตฌ๋ถ„์„  ์ถ”๊ฐ€
except Exception as e:
logger.error(f"Error displaying item: {e}")
st.error("Error displaying this item")
def process_search(image, mask, num_results):
    """Embed the masked image and look up the ``num_results`` most similar items.

    Shows a Streamlit spinner for each phase. Returns the list produced by
    ``search_similar_items``, or None if anything raises (the error is logged).
    """
    try:
        with st.spinner("Extracting features..."):
            query_vec = extract_features(image, mask)
        with st.spinner("Finding similar items..."):
            matches = search_similar_items(query_vec, top_k=num_results)
    except Exception as err:
        logger.error(f"Search processing error: {err}")
        return None
    return matches
# Callback functions
def handle_file_upload():
    """Store the uploaded file as an RGB image and advance to 'image_uploaded'.

    Bound as ``on_change`` of the file uploader (widget key ``uploaded_file``).
    Streamlit reruns the script after a callback returns, so no ``st.rerun()``
    is needed here; inside a callback it would be a no-op. If the file cannot
    be decoded as an image, the error is logged and shown, and the state stays
    'initial' so the user can retry.
    """
    uploaded = st.session_state.get('uploaded_file')
    if uploaded is None:
        return
    try:
        image = Image.open(uploaded).convert('RGB')
    except Exception as e:
        logger.error(f"Could not open uploaded image: {e}")
        st.error("Could not read the uploaded file as an image.")
        return
    st.session_state.image = image
    st.session_state.upload_state = 'image_uploaded'
def handle_detection():
    """Segment the uploaded image and store one entry per detected item.

    Bound as ``on_click`` of the "Detect Items" button. Stores a list in
    ``st.session_state.detected_items``. Each entry is a dict with:
        - ``'label'``: segment label reported by ``segmenter``;
        - ``'score'``: the segmenter's score, passed through as-is (may be None);
        - ``'mask'``: 2-D float array, 1.0 inside the item and 0.0 elsewhere.
    Segments labelled 'Background' and empty segments are skipped. If
    segmentation fails, the error is logged and an empty list is stored. No
    ``st.rerun()`` is called, because Streamlit reruns after a callback anyway.
    """
    image = st.session_state.image
    if image is None:
        return
    try:
        segments = segmenter(image) or []
    except Exception as e:
        logger.error(f"Detection error: {e}")
        segments = []

    items = []
    for seg in segments:
        label = seg.get('label', 'unknown')
        if label == 'Background':
            continue
        mask = np.asarray(seg['mask'])
        if mask.ndim > 2:
            mask = mask[:, :, 0]
        # Binarise: masks may be 0/255; downstream code multiplies by the mask.
        mask = mask > 0
        if not mask.any():
            continue
        items.append({'label': label, 'score': seg.get('score'), 'mask': mask.astype(float)})

    st.session_state.detected_items = items
    st.session_state.upload_state = 'items_detected'
def handle_search():
    """Set the sticky ``search_clicked`` flag in session state."""
    st.session_state['search_clicked'] = True
def admin_interface():
    """Render the admin panel, with a button that rebuilds the segmented collection.

    After a rebuild, shows a success or error message based on the boolean
    returned by ``update_db_with_segmentation``.
    """
    st.title("Admin Interface - DB Update")
    if not st.button("Update DB with Segmentation"):
        return
    with st.spinner("Updating database with segmentation... This may take a while..."):
        updated = update_db_with_segmentation()
    if updated:
        st.success("Database successfully updated with segmentation-based features!")
    else:
        st.error("Failed to update database. Please check the logs.")
def _render_new_search_button():
    """Render "Start New Search"; when clicked, clear all session state and rerun."""
    if st.button("Start New Search", key='new_search'):
        # Reset every state entry (widgets included) for a fresh start.
        for key in list(st.session_state.keys()):
            del st.session_state[key]
        st.rerun()


def main():
    """Render the Fashion Search page.

    Flow: upload an image, click "Detect Items" to segment it, pick a detected
    item, click "Search Similar Items", then browse the matches. A sidebar
    checkbox shows the admin panel for rebuilding the segmented collection.
    """
    st.title("Fashion Search App")

    # Admin controls in the sidebar
    st.sidebar.title("Admin Controls")
    if st.sidebar.checkbox("Show Admin Interface"):
        admin_interface()
        st.divider()

    # The file uploader is shown only while upload_state is 'initial'.
    if st.session_state.upload_state == 'initial':
        st.file_uploader("Upload an image", type=['png', 'jpg', 'jpeg'],
                         key='uploaded_file', on_change=handle_file_upload)

    image = st.session_state.image
    if image is None:
        return
    st.image(image, caption="Uploaded Image", use_column_width=True)

    detected_items = st.session_state.detected_items
    if detected_items is None:
        st.button("Detect Items", key='detect_button', on_click=handle_detection)
        return
    if not detected_items:
        st.warning("No clothing items were detected in this image.")
        _render_new_search_button()
        return

    # Show detected items in two columns
    image_array = np.array(image)
    cols = st.columns(2)
    for idx, item in enumerate(detected_items):
        with cols[idx % 2]:
            # Binarise the mask so 0/255 masks don't overflow the uint8 product.
            keep = np.asarray(item['mask']) > 0
            if keep.ndim > 2:
                keep = keep[:, :, 0]
            preview = image_array * keep[:, :, None]
            st.image(preview.astype(np.uint8), caption=f"Detected {item['label']}")
            st.write(f"Item {idx + 1}: {item['label']}")
            score = item.get('score')
            # Semantic-segmentation outputs may carry no score.
            if score is not None:
                st.write(f"Confidence: {score*100:.1f}%")

    # Item selection
    selected_idx = st.selectbox(
        "Select item to search:",
        range(len(detected_items)),
        format_func=lambda i: f"{detected_items[i]['label']}",
        key='item_selector',
    )

    # Search controls
    search_col1, search_col2 = st.columns([1, 2])
    with search_col1:
        search_clicked = st.button("Search Similar Items",
                                   key='search_button',
                                   type="primary")
    with search_col2:
        num_results = st.slider("Number of results:",
                                min_value=1,
                                max_value=20,
                                value=5,
                                key='num_results')

    # Once a search has been requested, keep showing results on later reruns.
    if not (search_clicked or st.session_state.get('search_clicked', False)):
        return
    st.session_state.search_clicked = True

    # Recompute when the button is pressed or when the selected item / result
    # count differs from the ones the cached results were computed for.
    query = (selected_idx, num_results)
    if (search_clicked
            or 'search_results' not in st.session_state
            or st.session_state.get('search_query') != query):
        selected_mask = detected_items[selected_idx]['mask']
        st.session_state.search_results = process_search(image, selected_mask, num_results)
        st.session_state.search_query = query

    # Show the stored search results
    if st.session_state.search_results:
        show_similar_items(st.session_state.search_results)
    else:
        st.warning("No similar items found.")

    _render_new_search_button()
# Script entry point (`streamlit run` executes this file as __main__).
if __name__ == "__main__":
    main()