"""Streamlit demo: detect clothing items in a photo and search similar products.

Pipeline: upload an image -> segment clothing items -> pick one item ->
embed the masked item with a CLIP-style model -> query a ChromaDB collection
for the nearest product embeddings and display them.
"""

import logging

import chromadb
import numpy as np
import open_clip
import streamlit as st
import torch
from PIL import Image
from transformers import pipeline

# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Single source of truth for the session-state keys this app owns.
_SESSION_DEFAULTS = {
    'image': None,
    'detected_items': None,
    'selected_item_index': None,
    'upload_state': 'initial',
    'search_clicked': False,
}


def _init_session_state():
    """Populate any missing session-state key with its default value."""
    for key, default in _SESSION_DEFAULTS.items():
        if key not in st.session_state:
            st.session_state[key] = default


_init_session_state()


@st.cache_resource
def load_models():
    """Load the embedding model, its preprocessing transform and the segmenter.

    Cached by Streamlit so the models are created once per server process.

    Returns:
        Tuple ``(model, preprocess_val, segmenter, device)``.

    Raises:
        Exception: whatever the underlying loaders raise; logged and re-raised.
    """
    try:
        # CLIP-style embedding model
        model, _, preprocess_val = open_clip.create_model_and_transforms(
            'hf-hub:Marqo/marqo-fashionSigLIP'
        )
        # Clothing segmentation model
        segmenter = pipeline(model="mattmdjaga/segformer_b2_clothes")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        # open_clip returns the model in train mode; switch to inference mode
        # so dropout / stochastic depth / batch-norm behave deterministically.
        model.eval()
        return model, preprocess_val, segmenter, device
    except Exception as e:
        logger.exception("Error loading models: %s", e)
        raise


# Load models
clip_model, preprocess_val, segmenter, device = load_models()

# ChromaDB setup
client = chromadb.PersistentClient(path="./clothesDB_11GmarketMusinsa")
collection = client.get_collection(name="clothes")


def process_segmentation(image):
    """Run the segmenter on ``image`` and return one entry per segment.

    Args:
        image: PIL image to segment.

    Returns:
        List of dicts with keys ``'score'``, ``'label'`` and ``'mask'``.
        ``'mask'`` is a 2-D ``uint8`` array of 0/1 values and ``'score'`` is
        the fraction of pixels covered by the mask (0.0-1.0). Returns ``[]``
        if segmentation fails (the error is logged).
    """
    try:
        segments = segmenter(image)
        valid_items = []
        for s in segments:
            # The pipeline mask is an 8-bit image (0 / 255). Normalise it to
            # 0 / 1 so the mean is a real fraction and multiplying it with a
            # uint8 image cannot overflow.
            mask_array = (np.array(s['mask']) > 0).astype(np.uint8)
            valid_items.append({
                'score': float(mask_array.mean()),
                'label': s['label'],
                'mask': mask_array,
            })
        return valid_items
    except Exception as e:
        logger.exception("Segmentation error: %s", e)
        return []


def extract_features(image, mask=None):
    """Return the L2-normalised embedding of ``image`` as a flat numpy array.

    Args:
        image: PIL image (RGB).
        mask: optional 2-D array; any non-zero pixel is kept, every other
            pixel is replaced by white before embedding.

    Raises:
        Exception: any preprocessing/model error; logged and re-raised.
    """
    try:
        if mask is not None:
            img_array = np.array(image)
            keep = np.asarray(mask) > 0
            # np.where avoids uint8 wrap-around regardless of mask scaling.
            masked_img = np.where(keep[:, :, np.newaxis], img_array, 255)
            image = Image.fromarray(masked_img.astype(np.uint8))

        image_tensor = preprocess_val(image).unsqueeze(0).to(device)
        with torch.no_grad():
            features = clip_model.encode_image(image_tensor)
            features /= features.norm(dim=-1, keepdim=True)
        return features.cpu().numpy().flatten()
    except Exception as e:
        logger.exception("Feature extraction error: %s", e)
        raise


def search_similar_items(features, top_k=10):
    """Query the collection for the ``top_k`` nearest items.

    Args:
        features: 1-D numpy embedding.
        top_k: number of results to request.

    Returns:
        List of metadata dicts, each with an added ``'similarity_score'``
        computed as ``1 / (1 + distance)``. Returns ``[]`` on error (logged).
    """
    try:
        results = collection.query(
            query_embeddings=[features.tolist()],
            n_results=top_k,
            include=['metadatas', 'distances'],
        )

        similar_items = []
        for metadata, distance in zip(results['metadatas'][0], results['distances'][0]):
            # Copy so the query result is not mutated; tolerate entries
            # stored without metadata.
            item = dict(metadata or {})
            # Convert the distance into a similarity score
            item['similarity_score'] = 1 / (1 + distance)
            similar_items.append(item)
        return similar_items
    except Exception as e:
        logger.exception("Search error: %s", e)
        return []


def _format_price(value):
    """Format a price for display; non-numeric values are shown as-is."""
    if isinstance(value, (int, float)) and not isinstance(value, bool):
        return f"{value:,}원"
    return str(value)


def show_similar_items(similar_items):
    """Render each similar item: image on the left, details on the right."""
    st.subheader("Similar Items:")
    for item in similar_items:
        col1, col2 = st.columns([1, 2])

        with col1:
            image_url = item.get('image_url')
            # Skip the image rather than crash when the metadata has no URL.
            if image_url:
                st.image(image_url)

        with col2:
            similarity = item.get('similarity_score')
            if similarity is not None:
                # Show the similarity score as a percentage
                st.write(f"Similarity: {similarity * 100:.1f}%")
            st.write(f"Brand: {item.get('brand', 'Unknown')}")
            st.write(f"Name: {item.get('name', 'Unknown')}")
            # The ',' format spec is only valid for numbers, so format
            # defensively: the price may be missing or stored as text.
            st.write(f"Price: {_format_price(item.get('price', 'Unknown'))}")

            if 'discount' in item:
                st.write(f"Discount: {item['discount']}%")
                if 'original_price' in item:
                    st.write(f"Original Price: {_format_price(item['original_price'])}")


def reset_state():
    """Clear every session-state entry and restore the app defaults."""
    for key in list(st.session_state.keys()):
        del st.session_state[key]
    _init_session_state()


# Callback functions.
# Streamlit reruns the script automatically after a widget callback, and
# st.rerun() is a no-op inside callbacks, so none of them call it.
def handle_file_upload():
    """Store the uploaded file as an RGB image in session state."""
    uploaded = st.session_state.get('uploaded_file')
    if uploaded is not None:
        st.session_state.image = Image.open(uploaded).convert('RGB')
        st.session_state.upload_state = 'image_uploaded'


def handle_detection():
    """Segment the stored image and keep the detected items in session state."""
    if st.session_state.image is not None:
        st.session_state.detected_items = process_segmentation(st.session_state.image)
        st.session_state.upload_state = 'items_detected'


def handle_search():
    """Mark that the user requested a similarity search."""
    st.session_state.search_clicked = True


def main():
    """Render the app: upload -> detect -> select item -> search."""
    st.title("포어블랙 fashion demo!!!")

    # File uploader (shown only while upload_state is 'initial')
    if st.session_state.upload_state == 'initial':
        st.file_uploader(
            "Upload an image",
            type=['png', 'jpg', 'jpeg'],
            key='uploaded_file',
            on_change=handle_file_upload,
        )

    # An image has been uploaded
    if st.session_state.image is not None:
        st.image(st.session_state.image, caption="Uploaded Image", use_column_width=True)

        detected_items = st.session_state.detected_items

        if detected_items is None:
            st.button("Detect Items", key='detect_button', on_click=handle_detection)
        elif not detected_items:
            # Detection ran but produced nothing (or failed and was logged).
            st.warning("No items detected.")

        # Show the detected items
        if detected_items:
            image_array = np.array(st.session_state.image)

            # Lay the detected items out in two columns
            cols = st.columns(2)
            for idx, item in enumerate(detected_items):
                with cols[idx % 2]:
                    keep = np.asarray(item['mask']) > 0
                    # Black out everything outside the item for the preview.
                    masked_img = np.where(keep[:, :, np.newaxis], image_array, 0)
                    st.image(masked_img.astype(np.uint8), caption=f"Detected {item['label']}")
                    st.write(f"Item {idx + 1}: {item['label']}")
                    # 'score' is the fraction of the image covered by the mask.
                    st.write(f"Confidence: {item['score']*100:.1f}%")

            # Item selection
            selected_idx = st.selectbox(
                "Select item to search:",
                range(len(detected_items)),
                format_func=lambda i: f"{detected_items[i]['label']}",
                key='item_selector',
            )
            st.session_state.selected_item_index = selected_idx

            # Similar-item search controls
            col1, col2 = st.columns([1, 2])
            with col1:
                st.button(
                    "Search Similar Items",
                    key='search_button',
                    on_click=handle_search,
                    type="primary",  # highlighted button
                )
            with col2:
                num_results = st.slider(
                    "Number of results:",
                    min_value=1,
                    max_value=20,
                    value=5,
                    key='num_results',
                )

            if st.session_state.search_clicked:
                with st.spinner("Searching similar items..."):
                    try:
                        selected_mask = detected_items[selected_idx]['mask']
                        features = extract_features(st.session_state.image, selected_mask)
                        similar_items = search_similar_items(features, top_k=num_results)

                        if similar_items:
                            show_similar_items(similar_items)
                        else:
                            st.warning("No similar items found.")
                    except Exception as e:
                        st.error(f"Error during search: {str(e)}")

        # New-search button: wipe the state and start over
        if st.button("Start New Search ", key='new_search'):
            reset_state()
            st.rerun()


if __name__ == "__main__":
    main()