import streamlit as st
import open_clip
import torch
from PIL import Image
import numpy as np
from transformers import pipeline
import chromadb
import logging
import io
import hashlib
import traceback
import requests
from concurrent.futures import ThreadPoolExecutor

# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Label of the non-garment class in the clothes segmentation model.
# NOTE(review): taken from the mattmdjaga/segformer_b2_clothes label map
# ("Background"); compared case-insensitively — confirm if the model changes.
BACKGROUND_LABEL = "background"

# Initialize session state. Streamlit re-executes this script on every rerun,
# so only keys that do not exist yet are set.
_SESSION_DEFAULTS = {
    'image': None,
    'detected_items': None,
    'selected_item_index': None,
    'upload_state': 'initial',
    'search_clicked': False,
}
for _key, _value in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _value


# Load models
@st.cache_resource
def load_models():
    """Load the CLIP model and the clothes segmenter once per process.

    Returns:
        Tuple ``(model, preprocess_val, segmenter, device)``.

    Raises:
        Exception: re-raised (after logging) if any model fails to load.
    """
    try:
        # CLIP model
        model, _, preprocess_val = open_clip.create_model_and_transforms('hf-hub:Marqo/marqo-fashionSigLIP')
        # Segmentation model
        segmenter = pipeline(model="mattmdjaga/segformer_b2_clothes")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        # open_clip creates models in train mode; switch to inference mode so
        # embeddings are deterministic.
        model.eval()
        return model, preprocess_val, segmenter, device
    except Exception as e:
        logger.error(f"Error loading models: {e}")
        raise


# Load models
clip_model, preprocess_val, segmenter, device = load_models()

# ChromaDB setup
client = chromadb.PersistentClient(path="./clothesDB_11GmarketMusinsa")
collection = client.get_collection(name="clothes")


def _mask_to_float(mask):
    """Convert a mask (PIL image or array) into a 2-D float array of 0.0/1.0."""
    mask = np.asarray(mask)
    # For multi-channel masks, use the first channel only.
    if mask.ndim > 2:
        mask = mask[:, :, 0]
    # The segmentation pipeline marks "inside" pixels as non-zero (e.g. 255).
    # Binarize so multiplying with pixel values never rescales them.
    return (mask > 0).astype(float)


def _segment(image):
    """Run the segmenter and return one dict per segment.

    Each dict has 'label', 'score' (may be None), 'mask' (2-D float 0/1 array)
    and 'area' (number of mask pixels).
    """
    output = segmenter(image)
    segments = []
    for seg in output or []:
        mask = _mask_to_float(seg['mask'])
        segments.append({
            'label': seg.get('label', 'unknown'),
            'score': seg.get('score'),
            'mask': mask,
            'area': float(mask.sum()),
        })
    return segments


def _is_background(segment):
    """Return True if *segment* is the model's background class."""
    return str(segment['label']).lower() == BACKGROUND_LABEL


def process_segmentation(image):
    """Return the mask of the largest non-background segment in *image*.

    Background segments are ignored unless they are the only ones found.

    Returns:
        2-D float array with 1.0 inside the segment and 0.0 elsewhere, or
        None if nothing was segmented or segmentation failed.
    """
    try:
        segments = _segment(image)
        if not segments:
            logger.warning("No segments found in image")
            return None

        candidates = [s for s in segments if not _is_background(s)] or segments
        largest = max(candidates, key=lambda s: s['area'])
        mask = largest['mask']

        logger.info(f"Successfully created mask with shape {mask.shape}")
        return mask
    except Exception as e:
        logger.error(f"Segmentation error: {str(e)}")
        logger.error(traceback.format_exc())
        return None


def detect_items(image):
    """Detect clothing items in *image* for display and selection in the UI.

    Returns:
        List of dicts with 'label', 'score' (may be None), 'mask' (2-D float
        0/1 array) and 'area', with background segments removed. Returns an
        empty list if nothing was found, or None if segmentation failed.
    """
    try:
        segments = _segment(image)
    except Exception as e:
        logger.error(f"Segmentation error: {str(e)}")
        logger.error(traceback.format_exc())
        return None
    return [s for s in segments if not _is_background(s)]


def download_and_process_image(image_url, metadata_id):
    """Download an image from a URL, segment it, and return its CLIP features.

    Returns:
        Flattened, L2-normalized feature vector, or None on any failure.
    """
    try:
        # Use a timeout so a slow host cannot hang a worker thread.
        response = requests.get(image_url, timeout=10)
        if response.status_code != 200:
            logger.error(f"Failed to download image {metadata_id}: HTTP {response.status_code}")
            return None

        image = Image.open(io.BytesIO(response.content)).convert('RGB')
        logger.info(f"Successfully downloaded image {metadata_id}")

        mask = process_segmentation(image)
        if mask is not None:
            features = extract_features(image, mask)
            logger.info(f"Successfully extracted features for image {metadata_id}")
            return features

        logger.warning(f"No valid mask found for image {metadata_id}")
        return None

    except Exception as e:
        logger.error(f"Error processing image {metadata_id}: {str(e)}")
        logger.error(traceback.format_exc())
        return None


def _stable_item_id(metadata):
    """Return a string id for *metadata*.

    Uses the metadata's 'id' field if set; otherwise a SHA-1 of its image URL.
    Chroma ids are strings, and the builtin ``hash()`` is salted per process,
    so it would produce a different id on every run.
    """
    raw_id = metadata.get('id')
    if raw_id is not None:
        return str(raw_id)
    return hashlib.sha1(str(metadata['image_url']).encode('utf-8')).hexdigest()


def update_db_with_segmentation():
    """Rebuild the "clothes_segmented" collection from the original collection.

    For every item in the original collection that has an image URL, this
    re-extracts features from the segmented image. Progress is shown in the
    Streamlit UI.

    Returns:
        True if at least one item was stored, otherwise False.
    """
    try:
        logger.info("Starting database update with segmentation")

        # Create a fresh collection.
        try:
            client.delete_collection("clothes_segmented")
            logger.info("Deleted existing segmented collection")
        except Exception:
            logger.info("No existing segmented collection to delete")

        new_collection = client.create_collection(
            name="clothes_segmented",
            metadata={"description": "Clothes collection with segmentation-based features"}
        )
        logger.info("Created new segmented collection")

        # Fetch only metadata from the original collection.
        try:
            all_items = collection.get(include=['metadatas'])
            total_items = len(all_items['metadatas'])
            logger.info(f"Found {total_items} items in database")
        except Exception as e:
            logger.error(f"Error getting items from collection: {str(e)}")
            # On error, continue with an empty list.
            all_items = {'metadatas': []}
            total_items = 0

        # Progress indicators.
        progress_bar = st.progress(0)
        status_text = st.empty()

        successful_updates = 0
        failed_updates = 0

        with ThreadPoolExecutor(max_workers=4) as executor:
            futures = []
            # Only process items with an image URL. Metadata entries may be None.
            valid_items = [m for m in all_items['metadatas'] if m and 'image_url' in m]

            for metadata in valid_items:
                future = executor.submit(
                    download_and_process_image,
                    metadata['image_url'],
                    metadata.get('id', 'unknown')
                )
                futures.append((metadata, future))

            # Collect results and store them in the new collection. Streamlit
            # calls happen only here, in the script thread.
            for idx, (metadata, future) in enumerate(futures):
                try:
                    new_features = future.result()
                    if new_features is not None:
                        item_id = _stable_item_id(metadata)
                        try:
                            new_collection.add(
                                embeddings=[new_features.tolist()],
                                metadatas=[metadata],
                                ids=[item_id]
                            )
                            successful_updates += 1
                            logger.info(f"Successfully added item {item_id}")
                        except Exception as e:
                            logger.error(f"Error adding item to new collection: {str(e)}")
                            failed_updates += 1
                    else:
                        failed_updates += 1

                    # Update progress.
                    progress = (idx + 1) / len(futures)
                    progress_bar.progress(progress)
                    status_text.text(f"Processing: {idx + 1}/{len(futures)} items. Success: {successful_updates}, Failed: {failed_updates}")

                except Exception as e:
                    logger.error(f"Error processing item: {str(e)}")
                    failed_updates += 1
                    continue

        # Final summary.
        status_text.text(f"Update completed. Successfully processed: {successful_updates}, Failed: {failed_updates}")
        logger.info(f"Database update completed. Successful: {successful_updates}, Failed: {failed_updates}")

        if successful_updates > 0:
            return True
        logger.error("No items were successfully processed")
        return False

    except Exception as e:
        logger.error(f"Database update error: {str(e)}")
        logger.error(traceback.format_exc())
        return False


def extract_features(image, mask=None):
    """Extract L2-normalized CLIP features, optionally from a masked image.

    Args:
        image: PIL RGB image.
        mask: optional 2-D mask with the image's height and width. Non-zero
            pixels are kept; the rest are painted white before encoding.

    Returns:
        Flattened numpy feature vector.

    Raises:
        Exception: re-raised (after logging) on any preprocessing/model error.
    """
    try:
        if mask is not None:
            img_array = np.array(image)
            # Binarize so a 0/255 mask cannot scale pixel values past uint8.
            mask = np.expand_dims(_mask_to_float(mask), axis=2)
            masked_img = img_array * mask
            masked_img[mask[:, :, 0] == 0] = 255  # white background
            image = Image.fromarray(masked_img.astype(np.uint8))

        image_tensor = preprocess_val(image).unsqueeze(0).to(device)
        with torch.no_grad():
            features = clip_model.encode_image(image_tensor)
            features /= features.norm(dim=-1, keepdim=True)
        return features.cpu().numpy().flatten()
    except Exception as e:
        logger.error(f"Feature extraction error: {e}")
        raise


def search_similar_items(features, top_k=10):
    """Query ChromaDB for the items nearest to *features*.

    Uses the "clothes_segmented" collection if it exists, else the original.

    Returns:
        List of metadata dicts, each with an added 'similarity_score'
        (1 / (1 + distance)), sorted from most to least similar. Returns an
        empty list on failure.
    """
    try:
        try:
            search_collection = client.get_collection("clothes_segmented")
            logger.info("Using segmented collection for search")
        except Exception:
            # Fall back to the original collection.
            search_collection = collection
            logger.info("Using original collection for search")

        results = search_collection.query(
            query_embeddings=[features.tolist()],
            n_results=top_k,
            include=['metadatas', 'distances']
        )

        if not results or not results['metadatas'] or not results['distances']:
            logger.warning("No results returned from ChromaDB")
            return []

        similar_items = []
        for metadata, distance in zip(results['metadatas'][0], results['distances'][0]):
            try:
                similarity_score = 1 / (1 + float(distance))
                item_data = metadata.copy()
                item_data['similarity_score'] = similarity_score
                similar_items.append(item_data)
            except Exception as e:
                logger.error(f"Error processing search result: {str(e)}")
                continue

        similar_items.sort(key=lambda x: x['similarity_score'], reverse=True)
        return similar_items

    except Exception as e:
        logger.error(f"Search error: {str(e)}")
        return []


def show_similar_items(similar_items):
    """Render search results in a two-column grid with similarity scores."""
    if not similar_items:
        st.warning("No similar items found.")
        return

    st.subheader("Similar Items:")

    items_per_row = 2
    for i in range(0, len(similar_items), items_per_row):
        cols = st.columns(items_per_row)
        for j, col in enumerate(cols):
            if i + j < len(similar_items):
                item = similar_items[i + j]
                with col:
                    try:
                        if 'image_url' in item:
                            st.image(item['image_url'], use_column_width=True)

                        # Similarity score as a percentage.
                        similarity_percent = item['similarity_score'] * 100
                        st.markdown(f"**Similarity: {similarity_percent:.1f}%**")

                        st.write(f"Brand: {item.get('brand', 'Unknown')}")
                        name = item.get('name', 'Unknown')
                        if len(name) > 50:  # truncate long names
                            name = name[:47] + "..."
                        st.write(f"Name: {name}")

                        # Price (thousands separator only for numeric values).
                        price = item.get('price', 0)
                        if isinstance(price, (int, float)):
                            st.write(f"Price: {price:,}원")
                        else:
                            st.write(f"Price: {price}")

                        # Discount information, if any.
                        if 'discount' in item and item['discount']:
                            st.write(f"Discount: {item['discount']}%")
                            if 'original_price' in item:
                                original_price = item['original_price']
                                if isinstance(original_price, (int, float)):
                                    st.write(f"Original: {original_price:,}원")
                                else:
                                    st.write(f"Original: {original_price}원")

                        st.divider()

                    except Exception as e:
                        logger.error(f"Error displaying item: {e}")
                        st.error("Error displaying this item")


def process_search(image, mask, num_results):
    """Extract features for the masked item and search for similar items.

    Returns:
        List of result dicts, or None if feature extraction or search failed.
    """
    try:
        with st.spinner("Extracting features..."):
            features = extract_features(image, mask)

        with st.spinner("Finding similar items..."):
            similar_items = search_similar_items(features, top_k=num_results)

        return similar_items
    except Exception as e:
        logger.error(f"Search processing error: {e}")
        return None


# Callback functions. Streamlit reruns the script after a callback returns,
# so no explicit st.rerun() is needed (it is a no-op inside callbacks).
def handle_file_upload():
    """Store the uploaded file as an RGB image in session state."""
    if st.session_state.uploaded_file is not None:
        try:
            image = Image.open(st.session_state.uploaded_file).convert('RGB')
        except Exception as e:
            logger.error(f"Error reading uploaded image: {e}")
            return
        st.session_state.image = image
        st.session_state.upload_state = 'image_uploaded'


def handle_detection():
    """Detect clothing items in the uploaded image and reset stale search results."""
    if st.session_state.image is not None:
        detected_items = detect_items(st.session_state.image)
        st.session_state.detected_items = detected_items
        st.session_state.upload_state = 'items_detected'
        # Results from a previous detection no longer match the items.
        st.session_state.pop('search_results', None)
        st.session_state.pop('search_key', None)


def handle_search():
    st.session_state.search_clicked = True


def admin_interface():
    """Render the admin panel with the DB rebuild button."""
    st.title("Admin Interface - DB Update")
    if st.button("Update DB with Segmentation"):
        with st.spinner("Updating database with segmentation... This may take a while..."):
            success = update_db_with_segmentation()
            if success:
                st.success("Database successfully updated with segmentation-based features!")
            else:
                st.error("Failed to update database. Please check the logs.")


def main():
    """Run the app flow: upload, detect items, select an item, search, reset."""
    st.title("Fashion Search App")

    # Admin controls in sidebar
    st.sidebar.title("Admin Controls")
    if st.sidebar.checkbox("Show Admin Interface"):
        admin_interface()
        st.divider()

    # The file uploader is shown only before an image has been uploaded.
    if st.session_state.upload_state == 'initial':
        st.file_uploader("Upload an image", type=['png', 'jpg', 'jpeg'],
                         key='uploaded_file',
                         on_change=handle_file_upload)

    if st.session_state.image is not None:
        st.image(st.session_state.image, caption="Uploaded Image", use_column_width=True)

        if st.session_state.detected_items is None:
            st.button("Detect Items", key='detect_button', on_click=handle_detection)
        elif not st.session_state.detected_items:
            st.warning("No items detected in the image.")

        if st.session_state.detected_items:
            # Show detected items in two columns.
            cols = st.columns(2)
            for idx, item in enumerate(st.session_state.detected_items):
                with cols[idx % 2]:
                    mask = item['mask']
                    masked_img = np.array(st.session_state.image) * np.expand_dims(mask, axis=2)
                    st.image(masked_img.astype(np.uint8), caption=f"Detected {item['label']}")
                    st.write(f"Item {idx + 1}: {item['label']}")
                    # Semantic segmentation gives no per-segment confidence.
                    if item.get('score') is not None:
                        st.write(f"Confidence: {item['score']*100:.1f}%")

            # Item selection
            selected_idx = st.selectbox(
                "Select item to search:",
                range(len(st.session_state.detected_items)),
                format_func=lambda i: f"{st.session_state.detected_items[i]['label']}",
                key='item_selector'
            )

            # Search controls
            search_col1, search_col2 = st.columns([1, 2])
            with search_col1:
                search_clicked = st.button("Search Similar Items",
                                           key='search_button',
                                           type="primary")
            with search_col2:
                num_results = st.slider("Number of results:",
                                        min_value=1, max_value=20, value=5,
                                        key='num_results')

            if search_clicked or st.session_state.get('search_clicked', False):
                st.session_state.search_clicked = True
                selected_mask = st.session_state.detected_items[selected_idx]['mask']

                # Reuse cached results only if they were produced for the
                # same item and result count; otherwise search again.
                search_key = (selected_idx, num_results)
                if st.session_state.get('search_key') != search_key:
                    st.session_state.search_results = process_search(
                        st.session_state.image, selected_mask, num_results)
                    st.session_state.search_key = search_key

                if st.session_state.get('search_results'):
                    show_similar_items(st.session_state.search_results)
                else:
                    st.warning("No similar items found.")

                if st.button("Start New Search", key='new_search'):
                    # Reset all state.
                    for key in list(st.session_state.keys()):
                        del st.session_state[key]
                    st.rerun()


if __name__ == "__main__":
    main()