import hashlib
import io
import logging
import traceback
from concurrent.futures import ThreadPoolExecutor

import chromadb
import numpy as np
import open_clip
import requests
import streamlit as st
import torch
from PIL import Image
from transformers import pipeline

# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Layout of the combined embedding stored in / queried from ChromaDB.
EMBEDDING_DIM = 768                      # total dimensionality of a stored vector
COLOR_BINS = 8                           # histogram bins per HSV channel
COLOR_DIM = COLOR_BINS * 3               # 24 colour features
CLIP_DIM = EMBEDDING_DIM - COLOR_DIM     # 744 leading dims kept from the CLIP vector
CLIP_WEIGHT = 0.7
COLOR_WEIGHT = 0.3

# Initialize session state (only keys that are not already present)
_SESSION_DEFAULTS = {
    'image': None,
    'detected_items': None,
    'selected_item_index': None,
    'upload_state': 'initial',
    'search_clicked': False,
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default


# Load models
@st.cache_resource
def load_models():
    """Load the CLIP-style image encoder and the clothes segmentation pipeline.

    Returns:
        Tuple ``(model, preprocess_val, segmenter, device)``. The encoder is
        moved to CUDA when available, otherwise it stays on the CPU.

    Raises:
        Exception: whatever the underlying loaders raise; the error is logged
        and re-raised.
    """
    try:
        # Image encoder
        model, _, preprocess_val = open_clip.create_model_and_transforms('hf-hub:Marqo/marqo-fashionSigLIP')

        # Segmentation model
        segmenter = pipeline(model="mattmdjaga/segformer_b2_clothes")

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)

        return model, preprocess_val, segmenter, device
    except Exception as e:
        logger.error(f"Error loading models: {e}")
        raise


# Load the models once (cached by Streamlit across reruns)
clip_model, preprocess_val, segmenter, device = load_models()

# ChromaDB setup
client = chromadb.PersistentClient(path="./clothesDB_11GmarketMusinsa")
collection = client.get_collection(name="clothes")


def _as_binary_mask(mask):
    """Return *mask* as a 2-D boolean array (True where the mask is non-zero).

    Accepts boolean, 0/1 or 0/255 masks, with or without trailing channel
    dimensions; when extra dimensions exist the first channel is used.
    """
    arr = np.asarray(mask)
    while arr.ndim > 2:
        arr = arr[..., 0]
    return arr > 0


def extract_color_histogram(image, mask=None):
    """Extract a normalized HSV colour histogram from *image*.

    Args:
        image: RGB image (PIL image or ``(H, W, 3)`` array).
        mask: optional mask; only pixels where the mask is non-zero are
            counted. May be 2-D or carry trailing channel dimensions.

    Returns:
        A 24-element array: 8 bins each for H, S and V, every channel
        normalized to sum to 1. A zero vector is returned when there are no
        valid pixels or when extraction fails (the error is logged).
    """
    try:
        if isinstance(image, Image.Image):
            image = image.convert('RGB')
        img_array = np.array(image)

        if mask is not None:
            # Only consider pixels that are part of the clothing item
            valid_pixels = img_array[_as_binary_mask(mask)]
        else:
            valid_pixels = img_array.reshape(-1, 3)

        if len(valid_pixels) == 0:
            return np.zeros(COLOR_DIM)

        # Shape the pixels as an N x 1 RGB image. A plain (N, 3) array would be
        # interpreted by PIL as a grayscale image, giving meaningless HSV values.
        rgb_column = valid_pixels.reshape(-1, 1, 3).astype(np.uint8)
        hsv_pixels = np.array(Image.fromarray(rgb_column).convert('HSV')).reshape(-1, 3)

        channel_hists = []
        for channel in range(3):
            hist = np.histogram(hsv_pixels[:, channel], bins=COLOR_BINS, range=(0, 256))[0]
            # Small epsilon avoids division by zero
            channel_hists.append(hist / (hist.sum() + 1e-8))

        return np.concatenate(channel_hists)
    except Exception as e:
        logger.error(f"Color histogram extraction error: {e}")
        return np.zeros(COLOR_DIM)


def process_segmentation(image):
    """Run the segmentation pipeline and normalize its output.

    Returns:
        A list of dicts with keys ``'label'``, ``'score'`` and ``'mask'``.
        ``'mask'`` is a 2-D float array containing only 0.0 and 1.0.
        Segments without a mask are skipped. An empty list is returned when
        nothing is found or when segmentation fails (the error is logged).
    """
    try:
        # Process the pipeline output directly
        output = segmenter(image)

        if not output:
            logger.warning("No segments found in image")
            return []

        processed_items = []
        for segment in output:
            processed_segment = {
                'label': segment.get('label', 'Unknown'),
                'score': segment.get('score', 1.0),  # default to 1.0 when no score is given
                'mask': None
            }

            mask = segment.get('mask')
            if mask is None:
                logger.warning(f"No mask found for segment with label {processed_segment['label']}")
                continue  # skip segments without a mask

            # Binarize before converting to float: a 0/255 mask multiplied with
            # a uint8 image would otherwise wrap around when cast back to uint8.
            processed_segment['mask'] = _as_binary_mask(mask).astype(float)
            processed_items.append(processed_segment)

        logger.info(f"Successfully processed {len(processed_items)} segments")
        return processed_items

    except Exception as e:
        logger.error(f"Segmentation error: {str(e)}")
        logger.error(traceback.format_exc())
        return []


def extract_features(image, mask=None):
    """Extract the combined CLIP + colour embedding for *image*.

    When *mask* is given, pixels outside the mask are replaced by white before
    encoding, and the colour histogram only counts pixels inside the mask.

    NOTE: colour features are now actually populated (they used to collapse to
    zeros when a mask was supplied), so collections indexed with the previous
    behaviour should be rebuilt with ``update_db_with_segmentation``.

    Returns:
        A unit-norm array of ``EMBEDDING_DIM`` (768) values: the first 744 hold
        the weighted CLIP features, the last 24 the weighted colour histogram.

    Raises:
        Exception: any failure is logged and re-raised.
    """
    try:
        mask_2d = None
        if mask is not None:
            mask_2d = _as_binary_mask(mask)
            img_array = np.array(image)
            # Keep the item pixels, set the background to white
            masked_img = np.where(mask_2d[:, :, None], img_array, 255)
            image = Image.fromarray(masked_img.astype(np.uint8))

        image_tensor = preprocess_val(image).unsqueeze(0).to(device)
        with torch.no_grad():
            clip_features = clip_model.encode_image(image_tensor)
            clip_features /= clip_features.norm(dim=-1, keepdim=True)
        clip_features = clip_features.cpu().numpy().flatten()

        if clip_features.shape[0] < CLIP_DIM:
            raise ValueError(
                f"Encoder returned {clip_features.shape[0]} features, expected at least {CLIP_DIM}"
            )

        # Extract colour features (2-D mask; the histogram handles masking itself)
        color_features = extract_color_histogram(image, mask_2d)

        # Trim the CLIP vector to make room for the colour features so the
        # total stays at the collection's dimensionality.
        clip_features = clip_features[:CLIP_DIM]

        # Normalize each part separately before weighting
        clip_features_normalized = clip_features / (np.linalg.norm(clip_features) + 1e-8)
        color_features_normalized = color_features / (np.linalg.norm(color_features) + 1e-8)

        combined_features = np.zeros(EMBEDDING_DIM)
        combined_features[:CLIP_DIM] = clip_features_normalized * CLIP_WEIGHT
        combined_features[CLIP_DIM:] = color_features_normalized * COLOR_WEIGHT

        # Ensure final normalization
        combined_features = combined_features / (np.linalg.norm(combined_features) + 1e-8)

        return combined_features
    except Exception as e:
        logger.error(f"Feature extraction error: {e}")
        raise


def download_and_process_image(image_url, metadata_id):
    """Download an image, segment it and extract features for the largest segment.

    Args:
        image_url: URL of the image to fetch (10 second timeout).
        metadata_id: identifier used only in log messages.

    Returns:
        The combined feature vector, or ``None`` when the download fails, no
        segment is found, or any step raises (errors are logged).
    """
    try:
        response = requests.get(image_url, timeout=10)
        if response.status_code != 200:
            logger.error(f"Failed to download image {metadata_id}: HTTP {response.status_code}")
            return None

        image = Image.open(io.BytesIO(response.content)).convert('RGB')
        logger.info(f"Successfully downloaded image {metadata_id}")

        processed_items = process_segmentation(image)
        if processed_items:
            # Use the mask of the largest segment.
            # NOTE(review): if the segmenter emits a background segment, it may
            # be the largest one — confirm whether it should be excluded here.
            largest_mask = max(processed_items, key=lambda x: np.sum(x['mask']))['mask']
            features = extract_features(image, largest_mask)
            logger.info(f"Successfully extracted features for image {metadata_id}")
            return features

        logger.warning(f"No valid mask found for image {metadata_id}")
        return None

    except Exception as e:
        logger.error(f"Error processing image {metadata_id}: {str(e)}")
        logger.error(traceback.format_exc())
        return None


def _stable_item_id(metadata):
    """Return a string id for *metadata*: its ``'id'`` or a digest of the image URL.

    ``hash()`` is not used because string hashes are salted per process, which
    would give the same item a different id on every run.
    """
    raw_id = metadata.get('id')
    if raw_id is not None:
        return str(raw_id)
    return hashlib.sha1(metadata['image_url'].encode('utf-8')).hexdigest()


def update_db_with_segmentation():
    """Rebuild the ``clothes_segmented`` collection from the original one.

    Every item of the original collection that has an ``'image_url'`` is
    downloaded, segmented and re-embedded (4 worker threads); successful items
    are added to a freshly created ``clothes_segmented`` collection. Progress
    is reported through a Streamlit progress bar and status text.

    Returns:
        ``True`` when at least one item was stored, ``False`` otherwise
        (including on unexpected errors, which are logged).
    """
    try:
        logger.info("Starting database update with segmentation and color features")

        # Create a fresh collection
        try:
            client.delete_collection("clothes_segmented")
            logger.info("Deleted existing segmented collection")
        except Exception:
            logger.info("No existing segmented collection to delete")

        new_collection = client.create_collection(
            name="clothes_segmented",
            metadata={"description": "Clothes collection with segmentation and color features"}
        )
        logger.info("Created new segmented collection")

        # Fetch only the metadata from the existing collection
        try:
            all_items = collection.get(include=['metadatas'])
            total_items = len(all_items['metadatas'])
            logger.info(f"Found {total_items} items in database")
        except Exception as e:
            logger.error(f"Error getting items from collection: {str(e)}")
            all_items = {'metadatas': []}

        # Progress reporting
        progress_bar = st.progress(0)
        status_text = st.empty()

        successful_updates = 0
        failed_updates = 0

        with ThreadPoolExecutor(max_workers=4) as executor:
            # Only process items that have an image URL
            valid_items = [m for m in all_items['metadatas'] if m and 'image_url' in m]

            futures = [
                (
                    metadata,
                    executor.submit(
                        download_and_process_image,
                        metadata['image_url'],
                        metadata.get('id', 'unknown')
                    ),
                )
                for metadata in valid_items
            ]

            # Collect results and store them in the new collection
            for idx, (metadata, future) in enumerate(futures):
                try:
                    new_features = future.result()
                    if new_features is None:
                        failed_updates += 1
                    else:
                        item_id = _stable_item_id(metadata)
                        try:
                            new_collection.add(
                                embeddings=[new_features.tolist()],
                                metadatas=[metadata],
                                ids=[item_id]
                            )
                            successful_updates += 1
                            logger.info(f"Successfully added item {item_id}")
                        except Exception as e:
                            logger.error(f"Error adding item to new collection: {str(e)}")
                            failed_updates += 1
                except Exception as e:
                    logger.error(f"Error processing item: {str(e)}")
                    failed_updates += 1

                # Update progress for every item, whether it succeeded or not
                progress_bar.progress((idx + 1) / len(futures))
                status_text.text(f"Processing: {idx + 1}/{len(futures)} items. Success: {successful_updates}, Failed: {failed_updates}")

        # Final summary
        status_text.text(f"Update completed. Successfully processed: {successful_updates}, Failed: {failed_updates}")
        logger.info(f"Database update completed. Successful: {successful_updates}, Failed: {failed_updates}")

        if successful_updates > 0:
            return True
        logger.error("No items were successfully processed")
        return False

    except Exception as e:
        logger.error(f"Database update error: {str(e)}")
        logger.error(traceback.format_exc())
        return False


def _distance_to_similarity_percent(distance, space="l2"):
    """Convert a ChromaDB distance into a similarity percentage in [0, 100].

    Assumes unit-norm embeddings (``extract_features`` produces them):
    for ``"l2"`` Chroma reports the squared L2 distance, which equals
    ``2 - 2*cos``; for ``"cosine"`` and ``"ip"`` it reports ``1 - cos``.
    The result is clamped so non-normalized stored vectors cannot produce
    values outside the range.
    """
    if space == "l2":
        cosine = 1.0 - distance / 2.0
    else:
        cosine = 1.0 - distance
    return float(min(max(cosine, 0.0), 1.0) * 100.0)


def search_similar_items(features, top_k=10):
    """Query ChromaDB for the items most similar to *features*.

    The ``clothes_segmented`` collection is used when it exists, otherwise the
    original collection.

    Returns:
        A list of metadata dicts, each with an added ``'similarity_score'``
        (percentage, higher is more similar), ordered best match first.
        An empty list is returned when nothing is found or the query fails.
    """
    try:
        # Prefer the collection built with segmentation, if present
        try:
            search_collection = client.get_collection("clothes_segmented")
            logger.info("Using segmented collection for search")
        except Exception:
            # Fall back to the original collection
            search_collection = collection
            logger.info("Using original collection for search")

        results = search_collection.query(
            query_embeddings=[features.tolist()],
            n_results=top_k,
            include=['metadatas', 'distances']
        )

        if not results or not results['metadatas'] or not results['distances']:
            logger.warning("No results returned from ChromaDB")
            return []

        # NOTE(review): the distance metric is read from the collection
        # metadata and defaults to Chroma's "l2"; newer Chroma versions may
        # expose it via the collection configuration instead — confirm.
        collection_meta = getattr(search_collection, "metadata", None) or {}
        space = collection_meta.get("hnsw:space", "l2")

        similar_items = []
        for metadata, distance in zip(results['metadatas'][0], results['distances'][0]):
            try:
                item_data = dict(metadata or {})
                # Chroma returns a distance (smaller is closer); turn it into
                # a similarity so sorting and the "%" display are meaningful.
                item_data['similarity_score'] = _distance_to_similarity_percent(distance, space)
                similar_items.append(item_data)
            except Exception as e:
                logger.error(f"Error processing search result: {str(e)}")
                continue

        similar_items.sort(key=lambda x: x['similarity_score'], reverse=True)
        return similar_items
    except Exception as e:
        logger.error(f"Search error: {str(e)}")
        return []


def _format_price(value):
    """Format numeric prices with thousands separators and the won suffix."""
    if isinstance(value, (int, float)):
        return f"{value:,}원"
    return f"{value}"


def show_similar_items(similar_items):
    """Display similar items in a two-column grid with their similarity scores."""
    if not similar_items:
        st.warning("No similar items found.")
        return

    st.subheader("Similar Items:")

    # Show the results in two columns
    items_per_row = 2
    for row_start in range(0, len(similar_items), items_per_row):
        cols = st.columns(items_per_row)
        row_items = similar_items[row_start:row_start + items_per_row]
        for col, item in zip(cols, row_items):
            with col:
                try:
                    if 'image_url' in item:
                        st.image(item['image_url'], use_column_width=True)

                    # Similarity score shown as a percentage
                    similarity_percent = item['similarity_score']
                    st.markdown(f"**Similarity: {similarity_percent:.1f}%**")

                    st.write(f"Brand: {item.get('brand', 'Unknown')}")
                    name = item.get('name', 'Unknown')
                    if len(name) > 50:  # shorten long names
                        name = name[:47] + "..."
                    st.write(f"Name: {name}")

                    # Price information
                    st.write(f"Price: {_format_price(item.get('price', 0))}")

                    # Discount information, when available
                    if 'discount' in item and item['discount']:
                        st.write(f"Discount: {item['discount']}%")
                        if 'original_price' in item:
                            st.write(f"Original: {_format_price(item['original_price'])}")

                    st.divider()  # separator between items

                except Exception as e:
                    logger.error(f"Error displaying item: {e}")
                    st.error("Error displaying this item")


def process_search(image, mask, num_results):
    """Extract features for the masked image and look up similar items.

    Returns:
        The list from ``search_similar_items``, or ``None`` when feature
        extraction or the search raises (the error is logged).
    """
    try:
        with st.spinner("Extracting features..."):
            features = extract_features(image, mask)

        with st.spinner("Finding similar items..."):
            similar_items = search_similar_items(features, top_k=num_results)

        return similar_items
    except Exception as e:
        logger.error(f"Search processing error: {e}")
        return None


def handle_file_upload():
    """Uploader ``on_change`` callback: store the uploaded image in session state.

    No explicit ``st.rerun()`` is issued: Streamlit reruns the script after a
    widget callback, and calling ``st.rerun()`` inside a callback is a no-op.
    """
    if st.session_state.uploaded_file is not None:
        image = Image.open(st.session_state.uploaded_file).convert('RGB')
        st.session_state.image = image
        st.session_state.upload_state = 'image_uploaded'


def handle_detection():
    """Button ``on_click`` callback: segment the stored image and keep the result.

    As with ``handle_file_upload``, the rerun after the callback is implicit.
    """
    if st.session_state.image is not None:
        detected_items = process_segmentation(st.session_state.image)
        st.session_state.detected_items = detected_items
        st.session_state.upload_state = 'items_detected'


def handle_search():
    """Mark that a search was requested."""
    st.session_state.search_clicked = True


def _render_detected_items_and_search():
    """Show the detected items, the item selector and the search results."""
    detected_items = st.session_state.detected_items

    # Show the detected items in two columns
    cols = st.columns(2)
    for idx, item in enumerate(detected_items):
        with cols[idx % 2]:
            try:
                if item.get('mask') is not None:
                    masked_img = np.array(st.session_state.image) * np.expand_dims(item['mask'], axis=2)
                    st.image(masked_img.astype(np.uint8), caption=f"Detected {item.get('label', 'Unknown')}")

                st.write(f"Item {idx + 1}: {item.get('label', 'Unknown')}")

                # Only show the score when it is present and numeric
                score = item.get('score')
                if score is not None and isinstance(score, (int, float)):
                    st.write(f"Confidence: {score*100:.1f}%")
                else:
                    st.write("Confidence: N/A")
            except Exception as e:
                logger.error(f"Error displaying item {idx}: {str(e)}")
                st.error(f"Error displaying item {idx}")

    valid_items = [i for i, item in enumerate(detected_items) if item.get('mask') is not None]

    if not valid_items:
        st.warning("No valid items detected for search.")
        return

    # Item selection
    selected_idx = st.selectbox(
        "Select item to search:",
        valid_items,
        format_func=lambda i: f"{st.session_state.detected_items[i].get('label', 'Unknown')}",
        key='item_selector'
    )

    # Search controls
    search_col1, search_col2 = st.columns([1, 2])
    with search_col1:
        search_clicked = st.button("Search Similar Items", key='search_button', type="primary")
    with search_col2:
        num_results = st.slider("Number of results:", min_value=1, max_value=20, value=5, key='num_results')

    if not (search_clicked or st.session_state.get('search_clicked', False)):
        return

    st.session_state.search_clicked = True
    selected_item = detected_items[selected_idx]

    if selected_item.get('mask') is None:
        st.error("Selected item has no valid mask for search.")
        return

    # Cache the results in session state, keyed by the inputs that produced
    # them, so changing the selected item or the result count triggers a new
    # search instead of showing stale results.
    search_params = (selected_idx, num_results)
    if ('search_results' not in st.session_state
            or st.session_state.get('search_params') != search_params):
        st.session_state.search_results = process_search(
            st.session_state.image, selected_item['mask'], num_results
        )
        st.session_state.search_params = search_params

    # Show the stored results
    if st.session_state.search_results:
        show_similar_items(st.session_state.search_results)
    else:
        st.warning("No similar items found.")


def main():
    """Render the Streamlit page: upload, detection, item selection and search."""
    st.title("Fashion Search App")

    # Admin controls in sidebar
    st.sidebar.title("Admin Controls")
    if st.sidebar.checkbox("Show Admin Interface"):
        # Admin interface placeholder
        st.sidebar.warning("Admin interface is not implemented yet.")

    st.divider()

    # File uploader (only before an image has been uploaded)
    if st.session_state.upload_state == 'initial':
        st.file_uploader(
            "Upload an image",
            type=['png', 'jpg', 'jpeg'],
            key='uploaded_file',
            on_change=handle_file_upload
        )

    if st.session_state.image is None:
        return

    # An image has been uploaded
    st.image(st.session_state.image, caption="Uploaded Image", use_column_width=True)

    if st.session_state.detected_items is None:
        st.button("Detect Items", key='detect_button', on_click=handle_detection)
    elif st.session_state.detected_items:
        _render_detected_items_and_search()
    else:
        # Detection ran but produced nothing usable
        st.warning("No valid items detected for search.")

    # Reset button: available whenever an image is loaded, so an empty
    # detection result does not leave the user without a way to start over.
    if st.button("Start New Search", key='new_search'):
        # Clear all state
        for key in list(st.session_state.keys()):
            del st.session_state[key]
        st.rerun()


if __name__ == "__main__":
    main()