# Web page header (file viewer) captured along with this file:
#   leedoming's picture
#   Create app.py
#   4531600 verified
#   raw
#   history blame
#   7.09 kB
import streamlit as st
import open_clip
import torch
import requests
from PIL import Image
from io import BytesIO
import time
import json
import numpy as np
import cv2
from inference_sdk import InferenceHTTPClient
import matplotlib.pyplot as plt
import base64
# Load model, evaluation transform and tokenizer (cached as a shared resource)
@st.cache_resource
def load_model(model_name='hf-hub:Marqo/marqo-fashionSigLIP'):
    """Load the open_clip model together with its eval transform and tokenizer.

    ``open_clip.create_model_and_transforms`` returns
    ``(model, preprocess_train, preprocess_val)``. Only the validation
    transform is kept for embedding, because the training transform is meant
    for augmentation. The tokenizer comes from ``open_clip.get_tokenizer``.

    Args:
        model_name: open_clip model identifier (defaults to the Marqo
            fashion SigLIP checkpoint on the Hugging Face hub).

    Returns:
        ``(model, preprocess_val, tokenizer, device)``. The model is moved to
        CUDA when available and put in eval mode.
    """
    model, _preprocess_train, preprocess_val = open_clip.create_model_and_transforms(model_name)
    tokenizer = open_clip.get_tokenizer(model_name)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # open_clip documents that models start in train mode; switch to
    # inference behaviour so embeddings are deterministic.
    model.eval()
    return model, preprocess_val, tokenizer, device


model, preprocess_val, tokenizer, device = load_model()
# Load and process data
@st.cache_data
def load_data():
    """Parse the product catalogue stored in ``musinsa-final.json``.

    The file is read as UTF-8 text. The decoded JSON value is returned
    unchanged and cached by Streamlit.
    """
    with open('musinsa-final.json', encoding='utf-8') as catalogue_file:
        catalogue = json.load(catalogue_file)
    return catalogue


data = load_data()
# Helper functions
@st.cache_data
def download_and_process_image(image_url, timeout=10):
    """Download an image over HTTP and return it as an RGB PIL image.

    Args:
        image_url: URL of the image to fetch.
        timeout: Seconds to wait on the server before giving up. Without a
            timeout ``requests.get`` can block indefinitely.

    Returns:
        The decoded image converted to RGB. Returns None if the download or
        decoding failed; the failure is reported with ``st.error``.
    """
    try:
        response = requests.get(image_url, timeout=timeout)
        response.raise_for_status()  # Raises an HTTPError for bad responses
        image = Image.open(BytesIO(response.content))
        # Image.open is lazy; decode now so corrupt data is caught here
        # rather than later, outside this try block.
        image.load()
        # Normalise every non-RGB mode (RGBA, P, LA, L, CMYK, ...) to RGB.
        # RGBA/P/LA images cannot be written as JPEG.
        if image.mode != 'RGB':
            image = image.convert('RGB')
        return image
    except requests.RequestException as e:
        st.error(f"Error downloading image: {e}")
        return None
    except Exception as e:
        st.error(f"Error processing image: {e}")
        return None
def get_image_embedding(image):
    """Embed a single PIL image with the loaded model.

    The image goes through ``preprocess_val`` and is batched as a single
    item on ``device``. It is encoded without gradient tracking and divided
    by its L2 norm along the last axis.

    Returns:
        A NumPy array (on CPU) holding the normalised image features.
    """
    batch = preprocess_val(image).unsqueeze(0).to(device)
    with torch.no_grad():
        features = model.encode_image(batch)
        features = features / features.norm(dim=-1, keepdim=True)
    return features.cpu().numpy()
def setup_roboflow_client(api_key):
    """Build an ``InferenceHTTPClient`` for the Roboflow outline endpoint.

    Args:
        api_key: Roboflow API key supplied by the user.

    Returns:
        The configured client instance.
    """
    endpoint = "https://outline.roboflow.com"
    client = InferenceHTTPClient(api_url=endpoint, api_key=api_key)
    return client
def _parse_roboflow_result(results):
    """Normalise a Roboflow ``infer`` response into a single result dict.

    Accepts a dict (returned as-is) or a JSON string (decoded). For a list
    response, the first entry is used; an empty list yields ``{}``.
    """
    if isinstance(results, str):
        results = json.loads(results)
    if isinstance(results, list):
        results = results[0] if results else {}
    return results


def segment_image(image_path, client, model_id="closet/1", size=(800, 600)):
    """Black out everything except the regions Roboflow segments in an image.

    The file at ``image_path`` is sent base64-encoded to the Roboflow model
    ``model_id``. A local copy is resized to ``size`` and every pixel outside
    the returned polygons is set to black.

    Args:
        image_path: Path to the image file on disk.
        client: Roboflow inference client used to call ``infer``.
        model_id: Roboflow model to query.
        size: ``(width, height)`` passed to ``cv2.resize`` for the local copy.

    Returns:
        A PIL image in RGB.
        - If the model returns no predictions, the resized, unmasked image is
          returned and a warning is shown.
        - On any error, the original file is returned unresized, converted to
          RGB, and an error is shown.
    """
    try:
        # Read the image file and base64-encode it for the API.
        with open(image_path, "rb") as image_file:
            encoded_image = base64.b64encode(image_file.read()).decode('utf-8')

        # Load the local copy that will be masked (OpenCV uses BGR order).
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals an unreadable file by returning None.
            raise ValueError(f"OpenCV could not read {image_path}")
        image = cv2.resize(image, size)

        # Call the Roboflow API and normalise its response to a dict.
        result = _parse_roboflow_result(client.infer(encoded_image, model_id=model_id))
        predictions = result.get('predictions', [])

        if not predictions:
            st.warning("No predictions found in the image. Returning original image.")
            segmented_image = image
        else:
            # Polygon points are rescaled from the frame described by
            # result['image'] (width/height) into the resized local copy.
            # Computed once, not per prediction.
            scale = np.array([
                image.shape[1] / result['image']['width'],
                image.shape[0] / result['image']['height'],
            ])
            mask = np.zeros(image.shape, dtype=np.uint8)
            for prediction in predictions:
                points = prediction.get('points') or []
                if not points:
                    continue  # nothing to draw for this prediction
                # Scale in float before rounding to pixel coordinates.
                pts = np.array([[p['x'], p['y']] for p in points], dtype=np.float64)
                pts = (pts * scale).astype(np.int32).reshape((-1, 1, 2))
                cv2.fillPoly(mask, [pts], color=(255, 255, 255))  # White mask
            segmented_image = cv2.bitwise_and(image, mask)

        return Image.fromarray(cv2.cvtColor(segmented_image, cv2.COLOR_BGR2RGB))
    except Exception as e:
        st.error(f"Error in segmentation: {str(e)}")
        # Fall back to the original file. Load it fully as RGB and release
        # the file handle instead of returning a lazily-opened image.
        with Image.open(image_path) as original:
            return original.convert('RGB')
@st.cache_data
def process_database_cached(data):
    """Download every catalogue image and store it as a local JPEG.

    Args:
        data: Iterable of product records (dicts). Each record is expected
            (assumption, not validated here) to have the following keys:
            - '이미지 링크': a list of image URLs; the first one is used.
            - '카테고리', '브랜드명', '제품명', '정가', '할인율'.
            - A product id under '상품 ID', possibly prefixed with a
              '\\ufeff' byte-order mark.

    Returns:
        A list of metadata dicts with keys id, category, brand, name, price,
        discount, image_url and temp_path. There is one dict per record whose
        image was downloaded. Records with no image link, or whose download
        failed, are skipped.

    NOTE(review): the JPEG files are a side effect. On a cache hit
    st.cache_data only returns the stored list, so files deleted while the
    cache is warm are not recreated.
    """
    database_info = []
    for index, item in enumerate(data):
        image_links = item.get('이미지 링크') or []
        if not image_links:
            continue  # no URL to download
        image_url = image_links[0]
        product_id = item.get('\ufeff상품 ID') or item.get('상품 ID')
        image = download_and_process_image(image_url)
        if image is None:
            continue
        # Save the image temporarily. Use the record index when the id is
        # missing, so id-less items do not all overwrite "temp_None.jpg".
        file_key = product_id if product_id is not None else f"idx{index}"
        temp_path = f"temp_{file_key}.jpg"
        # JPEG cannot store alpha/palette modes, so always write RGB.
        image.convert('RGB').save(temp_path, 'JPEG')
        database_info.append({
            'id': product_id,
            'category': item['카테고리'],
            'brand': item['브랜드명'],
            'name': item['제품명'],
            'price': item['정가'],
            'discount': item['할인율'],
            'image_url': image_url,
            'temp_path': temp_path,
        })
    return database_info
def process_database(client, data):
    """Segment and embed every downloadable catalogue image.

    Args:
        client: Roboflow inference client passed to ``segment_image``.
        data: Catalogue records, forwarded to ``process_database_cached``.

    Returns:
        ``(embeddings, database_info)``.
        - ``embeddings`` is ``np.vstack`` of one ``get_image_embedding``
          result per entry of ``database_info``, in the same order.
        - ``database_info`` is the metadata list.

    Raises:
        ValueError: if no catalogue image could be downloaded, so there is
            nothing to search against.

    NOTE(review): this function is not cached. Each call re-segments the
    whole catalogue through the Roboflow API. Consider caching the result
    once a policy for cached segmentation failures is decided.
    """
    database_info = process_database_cached(data)
    if not database_info:
        # np.vstack([]) would fail here with an opaque message.
        raise ValueError("No database images could be downloaded; nothing to search.")
    database_embeddings = [
        get_image_embedding(segment_image(item['temp_path'], client))
        for item in database_info
    ]
    return np.vstack(database_embeddings), database_info
# Similarity search
def find_similar_images(query_embedding, database_embeddings, database_info, top_k=5):
    """Return the ``top_k`` catalogue items most similar to ``query_embedding``.

    ``get_image_embedding`` divides each embedding by its L2 norm.
    ``query_embedding`` is one such result, and the rows of
    ``database_embeddings`` are stacked results in ``database_info`` order.
    The dot product is therefore the cosine similarity.

    Returns:
        A list of ``{'info': <database_info entry>, 'similarity': float}``,
        highest similarity first. The list is empty when the database is
        empty.
    """
    if len(database_info) == 0:
        return []
    query = np.asarray(query_embedding).reshape(-1)
    similarities = (np.asarray(database_embeddings) @ query).ravel()
    ranked = np.argsort(similarities)[::-1][:top_k]
    return [{'info': database_info[i], 'similarity': float(similarities[i])} for i in ranked]


# Streamlit app
st.title("Fashion Search App with Segmentation")
# API Key input
api_key = st.text_input("Enter your Roboflow API Key", type="password")
if api_key:
    CLIENT = setup_roboflow_client(api_key)
    # Initialize database_embeddings and database_info
    try:
        database_embeddings, database_info = process_database(CLIENT, data)
    except ValueError as e:
        st.error(f"Could not build the product database: {e}")
        st.stop()
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is not None:
        # Normalise to RGB so the upload can be saved as JPEG below.
        image = Image.open(uploaded_file).convert('RGB')
        st.image(image, caption='Uploaded Image', use_column_width=True)
        if st.button('Find Similar Items'):
            with st.spinner('Processing...'):
                # Save uploaded image temporarily
                temp_path = "temp_upload.jpg"
                image.save(temp_path)
                # Segment the uploaded image
                segmented_image = segment_image(temp_path, CLIENT)
                st.image(segmented_image, caption='Segmented Image', use_column_width=True)
                # Get embedding for segmented image
                query_embedding = get_image_embedding(segmented_image)
                similar_images = find_similar_images(query_embedding, database_embeddings, database_info)
            st.subheader("Similar Items:")
            for img in similar_images:
                col1, col2 = st.columns(2)
                with col1:
                    st.image(img['info']['image_url'], use_column_width=True)
                with col2:
                    st.write(f"Name: {img['info']['name']}")
                    st.write(f"Brand: {img['info']['brand']}")
                    st.write(f"Category: {img['info']['category']}")
                    st.write(f"Price: {img['info']['price']}")
                    st.write(f"Discount: {img['info']['discount']}%")
                    st.write(f"Similarity: {img['similarity']:.2f}")
else:
    st.warning("Please enter your Roboflow API Key to use the app.")